Compare commits
11 Commits
user/rcade
...
Cadene-pat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4aef34c8e | ||
|
|
f98200297d | ||
|
|
86bbd16d43 | ||
|
|
0f6e0f6d74 | ||
|
|
fc3e545e03 | ||
|
|
b98ea415c1 | ||
|
|
bbe9057225 | ||
|
|
8c4643687c | ||
|
|
fab037f78d | ||
|
|
03d647269e | ||
|
|
2252b42337 |
12
.github/workflows/build-docker-images.yml
vendored
@@ -17,6 +17,12 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
@@ -46,6 +52,12 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
steps:
|
||||
- name: Install Git LFS
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
1
.gitignore
vendored
@@ -121,6 +121,7 @@ celerybeat.pid
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
21
README.md
@@ -22,8 +22,21 @@
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
|
||||
</div>
|
||||
|
||||
|
||||
<h3 align="center">
|
||||
<p>State-of-the-art Machine Learning for real-world robotics</p>
|
||||
<p>State-of-the-art AI for real-world robotics</p>
|
||||
</h3>
|
||||
|
||||
---
|
||||
@@ -77,7 +90,7 @@ conda activate lerobot
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install .
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** Depending on your platform, If you encounter any build errors during this step
|
||||
@@ -91,7 +104,7 @@ For simulations, 🤗 LeRobot comes with gymnasium environments that can be inst
|
||||
|
||||
For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
```bash
|
||||
pip install ".[aloha, pusht]"
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
@@ -116,10 +129,12 @@ wandb login
|
||||
| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
|
||||
| | ├── envs # various sim environments: aloha, pusht, xarm
|
||||
| | ├── policies # various policies: act, diffusion, tdmpc
|
||||
| | ├── robot_devices # various real devices: dynamixel motors, opencv cameras, koch robots
|
||||
| | └── utils # various utilities
|
||||
| └── scripts # contains functions to execute via command line
|
||||
| ├── eval.py # load policy and evaluate it on an environment
|
||||
| ├── train.py # train a policy via imitation learning and/or reinforcement learning
|
||||
| ├── control_robot.py # teleoperate a real robot, record data, run a policy
|
||||
| ├── push_dataset_to_hub.py # convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub
|
||||
| └── visualize_dataset.py # load a dataset and render its demonstrations
|
||||
├── outputs # contains results of scripts execution: logs, videos, model checkpoints
|
||||
|
||||
@@ -9,6 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create virtual environment
|
||||
|
||||
@@ -13,6 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ ARG DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential cmake \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa ffmpeg \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
@@ -18,8 +18,6 @@ from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
output_directory = Path("outputs/eval/example_pusht_diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
# Download the diffusion policy for pusht environment
|
||||
pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
|
||||
@@ -27,6 +25,17 @@ pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))
|
||||
|
||||
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
|
||||
policy.eval()
|
||||
|
||||
# Check if GPU is available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
print("GPU is available. Device set to:", device)
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.")
|
||||
# Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed)
|
||||
policy.diffusion.num_inference_steps = 10
|
||||
|
||||
policy.to(device)
|
||||
|
||||
# Initialize evaluation environment to render two observation types:
|
||||
|
||||
1005
examples/7_get_started_with_real_robot.md
Normal file
@@ -44,7 +44,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
@@ -53,24 +53,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = root
|
||||
if self.root is None and DATA_DIR is not None:
|
||||
self.root = DATA_DIR
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, self.root, split)
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||
else:
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, self.root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, self.root)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, self.root)
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
|
||||
@property
|
||||
|
||||
@@ -23,11 +23,19 @@ from typing import Dict
|
||||
import datasets
|
||||
import torch
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
---
|
||||
This dataset was created using [🤗 LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
@@ -385,3 +393,29 @@ def cycle(iterable):
|
||||
yield next(iterator)
|
||||
except StopIteration:
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
api = HfApi()
|
||||
|
||||
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
|
||||
refs = [branch.ref for branch in branches]
|
||||
ref = f"refs/heads/{branch}"
|
||||
if ref in refs:
|
||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
card.text += text
|
||||
return card
|
||||
|
||||
@@ -210,6 +210,12 @@ def encode_video_frames(
|
||||
# redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
|
||||
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
|
||||
|
||||
if not video_path.exists():
|
||||
raise OSError(
|
||||
f"Video encoding did not work. File not found: {video_path}. "
|
||||
f"Try running the command manually to debug: `{''.join(ffmpeg_cmd)}`"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
|
||||
@@ -233,6 +233,9 @@ class Logger:
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
|
||||
@@ -38,7 +38,13 @@ from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class ACTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "act"],
|
||||
):
|
||||
"""
|
||||
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
|
||||
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
|
||||
@@ -134,26 +140,25 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
bsize = actions_hat.shape[0]
|
||||
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
|
||||
|
||||
out_dict = {}
|
||||
out_dict["l1_loss"] = l1_loss
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
out_dict["loss"] = l1_loss
|
||||
loss_dict["loss"] = l1_loss
|
||||
|
||||
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
|
||||
return out_dict
|
||||
return loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
||||
@@ -43,7 +43,13 @@ from lerobot.common.policies.utils import (
|
||||
)
|
||||
|
||||
|
||||
class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class DiffusionPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "diffusion-policy"],
|
||||
):
|
||||
"""
|
||||
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
|
||||
@@ -341,11 +347,7 @@ class DiffusionModel(nn.Module):
|
||||
in_episode_bound = ~batch["action_is_pad"]
|
||||
loss = loss * in_episode_bound.unsqueeze(-1)
|
||||
|
||||
# Compute average per item in the batch
|
||||
bsize = loss.shape[0]
|
||||
loss = loss.reshape(bsize, -1).mean(1)
|
||||
|
||||
return loss
|
||||
return loss.mean()
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
|
||||
@@ -41,7 +41,13 @@ from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class TDMPCPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc"],
|
||||
):
|
||||
"""Implementation of TD-MPC learning + inference.
|
||||
|
||||
Please note several warnings for this policy.
|
||||
@@ -396,39 +402,51 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||
# predicted from the (target model's) observation encoder.
|
||||
consistency_loss = (
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0)
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||
# rewards.
|
||||
reward_loss = (
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).sum(0)
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(reward_preds, reward, reduction="none")
|
||||
* ~batch["next.reward_is_pad"]
|
||||
# `reward_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
q_value_loss = (
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0)
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(
|
||||
q_preds_ensemble,
|
||||
einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute state value loss as in eqn 3 of FOWM.
|
||||
diff = v_targets - v_preds
|
||||
# Expectile loss penalizes:
|
||||
@@ -438,12 +456,16 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight)
|
||||
) * (diff**2)
|
||||
v_value_loss = (
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).sum(0)
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* raw_v_value_loss
|
||||
# `v_targets` depends on the first observation and the actions, as does `v_preds`.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
@@ -476,7 +498,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).sum(0)
|
||||
).mean()
|
||||
|
||||
loss = (
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
@@ -488,13 +510,13 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
info.update(
|
||||
{
|
||||
"consistency_loss": consistency_loss,
|
||||
"reward_loss": reward_loss,
|
||||
"Q_value_loss": q_value_loss,
|
||||
"V_value_loss": v_value_loss,
|
||||
"pi_loss": pi_loss,
|
||||
"consistency_loss": consistency_loss.item(),
|
||||
"reward_loss": reward_loss.item(),
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss * self.config.horizon,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -13,13 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import nn
|
||||
|
||||
|
||||
@@ -53,26 +47,3 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
|
||||
Note: assumes that all parameters have the same dtype.
|
||||
"""
|
||||
return next(iter(module.parameters())).dtype
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
return pretrained_policy_path
|
||||
|
||||
@@ -38,7 +38,13 @@ from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
# ruff: noqa: N806
|
||||
|
||||
|
||||
class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
class VQBeTPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vqbet"],
|
||||
):
|
||||
"""
|
||||
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
|
||||
"""
|
||||
@@ -289,7 +295,7 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
|
||||
@@ -5,6 +5,7 @@ This file contains utilities for recording frames from cameras. For more info lo
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import math
|
||||
import platform
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
@@ -33,8 +34,22 @@ MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX):
|
||||
if platform.system() == "Linux":
|
||||
# Linux uses camera ports
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
possible_camera_ids = []
|
||||
for port in Path("/dev").glob("video*"):
|
||||
camera_idx = int(str(port).replace("/dev/video", ""))
|
||||
possible_camera_ids.append(camera_idx)
|
||||
else:
|
||||
print(
|
||||
"Mac or Windows detected. Finding available camera indices through "
|
||||
f"scanning all indices from 0 to {MAX_OPENCV_INDEX}"
|
||||
)
|
||||
possible_camera_ids = range(max_index_search_range)
|
||||
|
||||
camera_ids = []
|
||||
for camera_idx in range(max_index_search_range):
|
||||
for camera_idx in possible_camera_ids:
|
||||
camera = cv2.VideoCapture(camera_idx)
|
||||
is_open = camera.isOpened()
|
||||
camera.release()
|
||||
@@ -45,7 +60,8 @@ def find_camera_indices(raise_when_empty=False, max_index_search_range=MAX_OPENC
|
||||
|
||||
if raise_when_empty and len(camera_ids) == 0:
|
||||
raise OSError(
|
||||
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, or your camera driver, or make sure your camera is compatible with opencv2."
|
||||
"Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, "
|
||||
"or your camera driver, or make sure your camera is compatible with opencv2."
|
||||
)
|
||||
|
||||
return camera_ids
|
||||
@@ -59,10 +75,9 @@ def save_image(img_array, camera_index, frame_index, images_dir):
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
images_dir: Path, camera_ids=None, fps=None, width=None, height=None, record_time_s=2
|
||||
images_dir: Path, camera_ids: list[int] | None = None, fps=None, width=None, height=None, record_time_s=2
|
||||
):
|
||||
if camera_ids is None:
|
||||
print("Finding available camera indices")
|
||||
camera_ids = find_camera_indices()
|
||||
|
||||
print("Connecting cameras")
|
||||
@@ -71,13 +86,12 @@ def save_images_from_cameras(
|
||||
camera = OpenCVCamera(cam_idx, fps=fps, width=width, height=height)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
images_dir = Path(
|
||||
images_dir,
|
||||
)
|
||||
images_dir = Path(images_dir)
|
||||
if images_dir.exists():
|
||||
shutil.rmtree(
|
||||
images_dir,
|
||||
@@ -160,7 +174,7 @@ class OpenCVCamera:
|
||||
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of usage of the class:
|
||||
Example of usage:
|
||||
```python
|
||||
camera = OpenCVCamera(camera_index=0)
|
||||
camera.connect()
|
||||
@@ -194,11 +208,6 @@ class OpenCVCamera:
|
||||
self.height = config.height
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
if not isinstance(self.camera_index, int):
|
||||
raise ValueError(
|
||||
f"Camera index must be provided as an int, but {self.camera_index} was given instead."
|
||||
)
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
self.thread = None
|
||||
@@ -212,7 +221,13 @@ class OpenCVCamera:
|
||||
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(self.camera_index)
|
||||
|
||||
if platform.system() == "Linux":
|
||||
# Linux uses ports for connecting to cameras
|
||||
tmp_camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
|
||||
else:
|
||||
tmp_camera = cv2.VideoCapture(self.camera_index)
|
||||
|
||||
is_camera_open = tmp_camera.isOpened()
|
||||
# Release camera to make it accessible for `find_camera_indices`
|
||||
del tmp_camera
|
||||
@@ -224,7 +239,8 @@ class OpenCVCamera:
|
||||
available_cam_ids = find_camera_indices()
|
||||
if self.camera_index not in available_cam_ids:
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead."
|
||||
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
|
||||
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access camera {self.camera_index}.")
|
||||
@@ -232,7 +248,10 @@ class OpenCVCamera:
|
||||
# Secondly, create the camera that will be used downstream.
|
||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||
# needs to be re-created.
|
||||
self.camera = cv2.VideoCapture(self.camera_index)
|
||||
if platform.system() == "Linux":
|
||||
self.camera = cv2.VideoCapture(f"/dev/video{self.camera_index}")
|
||||
else:
|
||||
self.camera = cv2.VideoCapture(self.camera_index)
|
||||
|
||||
if self.fps is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
|
||||
@@ -2,6 +2,7 @@ from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
|
||||
cv2.imwrite(str(path), depth_image)
|
||||
|
||||
|
||||
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
|
||||
assert tensor.ndim == 3
|
||||
c, h, w = tensor.shape
|
||||
assert c < h and c < w
|
||||
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
|
||||
if rgb_to_bgr:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
return color_image
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
|
||||
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from dynamixel_sdk import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
@@ -21,9 +22,11 @@ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError,
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUD_RATE = 1_000_000
|
||||
BAUDRATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
|
||||
@@ -86,6 +89,16 @@ X_SERIES_CONTROL_TABLE = {
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
X_SERIES_BAUDRATE_TABLE = {
|
||||
0: 9_600,
|
||||
1: 57_600,
|
||||
2: 115_200,
|
||||
3: 1_000_000,
|
||||
4: 2_000_000,
|
||||
5: 3_000_000,
|
||||
6: 4_000_000,
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
@@ -98,7 +111,67 @@ MODEL_CONTROL_TABLE = {
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"x_series": 4096,
|
||||
"xl330-m077": 4096,
|
||||
"xl330-m288": 4096,
|
||||
"xl430-w250": 4096,
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"x_series": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]):
|
||||
"""This function convert the degree range to the step range for indicating motors rotation.
|
||||
It assums a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
"""
|
||||
if isinstance(degrees, float):
|
||||
degrees = np.array(degrees)
|
||||
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
@@ -207,13 +280,12 @@ class DynamixelMotorsBus:
|
||||
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
To find the motor indices, use [DynamixelWizzard2](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2).
|
||||
|
||||
Example of usage for 1 motor connected to the bus:
|
||||
```python
|
||||
motor_name = "gripper"
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m077"
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
motors_bus = DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
@@ -221,7 +293,11 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
motors_bus.teleop_step()
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# when done, consider disconnecting
|
||||
motors_bus.disconnect()
|
||||
@@ -233,6 +309,7 @@ class DynamixelMotorsBus:
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
@@ -241,6 +318,10 @@ class DynamixelMotorsBus:
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
@@ -268,52 +349,286 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
raise
|
||||
|
||||
self.port_handler.setBaudRate(BAUD_RATE)
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Set expected baudrate for the bus
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
|
||||
if not self.are_motors_configured():
|
||||
input(
|
||||
"\n/!\\ A configuration issue has been detected with your motors: \n"
|
||||
"If it's the first time that you use these motors, press enter to configure your motors... but before "
|
||||
"verify that all the cables are connected the proper way. If you find an issue, before making a modification, "
|
||||
"kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power "
|
||||
"again and relaunch the script.\n"
|
||||
)
|
||||
print()
|
||||
self.configure_motors()
|
||||
|
||||
def reconnect(self):
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
||||
# a ConnectionError will be raised anyway.
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def configure_motors(self):
|
||||
# TODO(rcadene): This script assumes motors follow the X_SERIES baudrates
|
||||
# TODO(rcadene): Refactor this function with intermediate high-level functions
|
||||
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(X_SERIES_BAUDRATE_TABLE.values())
|
||||
ids_per_baudrate = {}
|
||||
for baudrate in all_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices()
|
||||
if len(present_ids) > 0:
|
||||
ids_per_baudrate[baudrate] = present_ids
|
||||
print(f"Motor indices detected: {ids_per_baudrate}")
|
||||
print()
|
||||
|
||||
possible_baudrates = list(ids_per_baudrate.keys())
|
||||
possible_ids = list({idx for sublist in ids_per_baudrate.values() for idx in sublist})
|
||||
untaken_ids = list(set(range(MAX_ID_RANGE)) - set(possible_ids) - set(self.motor_indices))
|
||||
|
||||
# Connect successively one motor to the chain and write a unique random index for each
|
||||
for i in range(len(self.motors)):
|
||||
self.disconnect()
|
||||
input(
|
||||
"1. Unplug the power cord\n"
|
||||
"2. Plug/unplug minimal number of cables to only have the first "
|
||||
f"{i+1} motor(s) ({self.motor_names[:i+1]}) connected.\n"
|
||||
"3. Re-plug the power cord\n"
|
||||
"Press Enter to continue..."
|
||||
)
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
if i > 0:
|
||||
try:
|
||||
self._read_with_motor_ids(self.motor_models, untaken_ids[:i], "ID")
|
||||
except ConnectionError:
|
||||
print(f"Failed to read from {untaken_ids[:i+1]}. Make sure the power cord is plugged in.")
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
print("Scanning possible baudrates and motor indices")
|
||||
motor_found = False
|
||||
for baudrate in possible_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices(possible_ids)
|
||||
if len(present_ids) == 1:
|
||||
present_idx = present_ids[0]
|
||||
print(f"Detected motor with index {present_idx}")
|
||||
|
||||
if baudrate != BAUDRATE:
|
||||
print(f"Setting its baudrate to {BAUDRATE}")
|
||||
baudrate_idx = list(X_SERIES_BAUDRATE_TABLE.values()).index(BAUDRATE)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
for _ in range(NUM_WRITE_RETRY):
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate", baudrate_idx
|
||||
)
|
||||
time.sleep(0.5)
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
try:
|
||||
present_baudrate_idx = self._read_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate"
|
||||
)
|
||||
except ConnectionError:
|
||||
print("Failed to write baudrate. Retrying.")
|
||||
self.set_bus_baudrate(baudrate)
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
if present_baudrate_idx != baudrate_idx:
|
||||
raise OSError("Failed to write baudrate.")
|
||||
|
||||
print(f"Setting its index to a temporary untaken index ({untaken_ids[i]})")
|
||||
self._write_with_motor_ids(self.motor_models, present_idx, "ID", untaken_ids[i])
|
||||
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, untaken_ids[i], "ID")
|
||||
if present_idx != untaken_ids[i]:
|
||||
raise OSError("Failed to write index.")
|
||||
|
||||
motor_found = True
|
||||
break
|
||||
elif len(present_ids) > 1:
|
||||
raise OSError(f"More than one motor detected ({present_ids}), but only one was expected.")
|
||||
|
||||
if not motor_found:
|
||||
raise OSError(
|
||||
"No motor found, but one new motor expected. Verify power cord is plugged in and retry."
|
||||
)
|
||||
print()
|
||||
|
||||
print(f"Setting expected motor indices: {self.motor_indices}")
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, untaken_ids[: len(self.motors)], "ID", self.motor_indices
|
||||
)
|
||||
print()
|
||||
|
||||
if (self.read("ID") != self.motor_indices).any():
|
||||
raise OSError("Failed to write motors indices.")
|
||||
|
||||
print("Configuration is done!")
|
||||
|
||||
def find_motor_indices(self, possible_ids=None):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, [idx], "ID")[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[int]:
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if not self.calibration:
|
||||
return values
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
|
||||
values = values.astype(np.int32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
if values[i] is not None:
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
values[i] += homing_offset
|
||||
# Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
|
||||
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range ]-resolution, resolution[ to the universal float32 centered degree range ]-180, 180[
|
||||
values = values.astype(np.float32)
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / (resolution // 2) * 180
|
||||
|
||||
return values
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if not self.calibration:
|
||||
return values
|
||||
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from the universal float32 centered degree range ]-180, 180[ to resolution range ]-resolution, resolution[
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / 180 * (resolution // 2)
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
|
||||
# Convert from nominal range ]-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
values[i] -= homing_offset
|
||||
|
||||
if values[i] is not None:
|
||||
values[i] -= homing_offset
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
|
||||
# than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
return values
|
||||
|
||||
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
|
||||
return_list = True
|
||||
if not isinstance(motor_ids, list):
|
||||
return_list = False
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
comm = group.txRxPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
return values
|
||||
else:
|
||||
return values[0]
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
@@ -367,9 +682,21 @@ class DynamixelMotorsBus:
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
|
||||
# We expect our motors to stay in a nominal range of [-180, 180] degrees
|
||||
# which corresponds to a half turn rotation.
|
||||
# However, some motors can turn a bit more, hence we extend the nominal range to [-270, 270]
|
||||
# which is less than a full 360 degree rotation.
|
||||
if not np.all((values > -270) & (values < 270)):
|
||||
raise ValueError(
|
||||
f"Wrong motor position range detected. "
|
||||
f"Expected to be in [-270, +270] but in [{values.min()}, {values.max()}]. "
|
||||
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
@@ -380,6 +707,26 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
|
||||
if not isinstance(motor_ids, list):
|
||||
motor_ids = [motor_ids]
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes)
|
||||
group.addParam(idx, data)
|
||||
|
||||
comm = group.txPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
@@ -406,7 +753,7 @@ class DynamixelMotorsBus:
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.revert_calibration(values, motor_names)
|
||||
|
||||
values = values.tolist()
|
||||
@@ -422,30 +769,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
|
||||
data = convert_to_bytes(value, bytes)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
|
||||
@@ -1,46 +1,7 @@
|
||||
def make_robot(name):
|
||||
if name == "koch":
|
||||
# TODO(rcadene): Add configurable robot from command line and yaml config
|
||||
# TODO(rcadene): Add example with and without cameras
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
robot = KochRobot(
|
||||
leader_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
},
|
||||
follower_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
},
|
||||
cameras={
|
||||
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Robot '{name}' not found.")
|
||||
|
||||
def make_robot(cfg: DictConfig):
|
||||
robot = hydra.utils.instantiate(cfg)
|
||||
return robot
|
||||
|
||||
@@ -8,122 +8,43 @@ import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
DriveMode,
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
|
||||
URL_HORIZONTAL_POSITION = {
|
||||
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_horizontal.png",
|
||||
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_horizontal.png",
|
||||
}
|
||||
URL_90_DEGREE_POSITION = {
|
||||
"follower": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/follower_90_degree.png",
|
||||
"leader": "https://raw.githubusercontent.com/huggingface/lerobot/main/media/koch/leader_90_degree.png",
|
||||
}
|
||||
|
||||
########################################################################
|
||||
# Calibration logic
|
||||
########################################################################
|
||||
|
||||
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
||||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||
GRIPPER_OPEN = np.array([-400])
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
# In nominal degree range ]-180, +180[
|
||||
ZERO_POSITION_DEGREE = 0
|
||||
ROTATED_POSITION_DEGREE = 90
|
||||
GRIPPER_OPEN_DEGREE = 35.156
|
||||
|
||||
|
||||
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None:
|
||||
values[i] += homing_offset[i]
|
||||
return values
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(values: np.array, drive_mode: np.array) -> np.array:
|
||||
for i in range(len(values)):
|
||||
if values[i] is not None and drive_mode[i]:
|
||||
values[i] = -values[i]
|
||||
return values
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
assert_drive_mode(drive_mode)
|
||||
# Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
|
||||
# to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
|
||||
signed_drive_mode = -(drive_mode * 2 - 1)
|
||||
position *= signed_drive_mode
|
||||
return position
|
||||
|
||||
|
||||
def apply_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
values = apply_homing_offset(values, homing_offset)
|
||||
return values
|
||||
|
||||
|
||||
def revert_calibration(values: np.array, homing_offset: np.array, drive_mode: np.array) -> np.array:
|
||||
"""
|
||||
Transform working position into real position for the robot.
|
||||
"""
|
||||
values = apply_homing_offset(
|
||||
values,
|
||||
np.array([-homing_offset if homing_offset is not None else None for homing_offset in homing_offset]),
|
||||
)
|
||||
values = apply_drive_mode(values, drive_mode)
|
||||
return values
|
||||
|
||||
|
||||
def revert_appropriate_positions(positions: np.array, drive_mode: list[bool]) -> np.array:
|
||||
for i, revert in enumerate(drive_mode):
|
||||
if not revert and positions[i] is not None:
|
||||
positions[i] = -positions[i]
|
||||
return positions
|
||||
|
||||
|
||||
def compute_corrections(positions: np.array, drive_mode: list[bool], target_position: np.array) -> np.array:
|
||||
correction = revert_appropriate_positions(positions, drive_mode)
|
||||
|
||||
for i in range(len(positions)):
|
||||
if correction[i] is not None:
|
||||
if drive_mode[i]:
|
||||
correction[i] -= target_position[i]
|
||||
else:
|
||||
correction[i] += target_position[i]
|
||||
|
||||
return correction
|
||||
|
||||
|
||||
def compute_nearest_rounded_positions(positions: np.array) -> np.array:
|
||||
return np.array(
|
||||
[
|
||||
round(positions[i] / 1024) * 1024 if positions[i] is not None else None
|
||||
for i in range(len(positions))
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def compute_homing_offset(
|
||||
arm: DynamixelMotorsBus, drive_mode: list[bool], target_position: np.array
|
||||
) -> np.array:
|
||||
# Get the present positions of the servos
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"), np.array([0, 0, 0, 0, 0, 0]), drive_mode
|
||||
)
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
correction = compute_corrections(nearest_positions, drive_mode, target_position)
|
||||
return correction
|
||||
|
||||
|
||||
def compute_drive_mode(arm: DynamixelMotorsBus, offset: np.array):
|
||||
# Get current positions
|
||||
present_positions = apply_calibration(
|
||||
arm.read("Present_Position"), offset, np.array([False, False, False, False, False, False])
|
||||
)
|
||||
|
||||
nearest_positions = compute_nearest_rounded_positions(present_positions)
|
||||
|
||||
# construct 'drive_mode' list comparing nearest_positions and TARGET_90_DEGREE_POSITION
|
||||
drive_mode = []
|
||||
for i in range(len(nearest_positions)):
|
||||
drive_mode.append(nearest_positions[i] != TARGET_90_DEGREE_POSITION[i])
|
||||
return drive_mode
|
||||
|
||||
|
||||
def reset_arm(arm: MotorsBus):
|
||||
def reset_torque_mode(arm: MotorsBus):
|
||||
# To be configured, all servos must be in "torque disable" mode
|
||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value)
|
||||
|
||||
@@ -132,55 +53,95 @@ def reset_arm(arm: MotorsBus):
|
||||
# you could end up with a servo with a position 0 or 4095 at a crucial point See [
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11]
|
||||
all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"]
|
||||
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
|
||||
if len(all_motors_except_gripper) > 0:
|
||||
arm.write("Operating_Mode", OperatingMode.EXTENDED_POSITION.value, all_motors_except_gripper)
|
||||
|
||||
# TODO(rcadene): why?
|
||||
# Use 'position control current based' for gripper
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current.
|
||||
# For the follower gripper, it means it can grasp an object without forcing too much even tho,
|
||||
# it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger
|
||||
# to make it move, and it will move back to its original target position when we release the force.
|
||||
arm.write("Operating_Mode", OperatingMode.CURRENT_CONTROLLED_POSITION.value, "gripper")
|
||||
|
||||
# Make sure the native calibration (homing offset abd drive mode) is disabled, since we use our own calibration layer to be more generic
|
||||
arm.write("Homing_Offset", 0)
|
||||
arm.write("Drive_Mode", DriveMode.NON_INVERTED.value)
|
||||
|
||||
|
||||
def run_arm_calibration(arm: MotorsBus, name: str, arm_type: str):
|
||||
"""Example of usage:
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
the two robots will move to the same position.To this end, this function computes the homing offset
|
||||
and the drive mode for each motor of a given robot.
|
||||
|
||||
Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
|
||||
to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
|
||||
being 0. During the calibration process, you will need to manually move the robot to this "zero position".
|
||||
|
||||
Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
|
||||
in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
|
||||
to the "rotated position".
|
||||
|
||||
After calibration, the homing offsets and drive modes are stored in a cache.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
run_arm_calibration(arm, "left", "follower")
|
||||
```
|
||||
"""
|
||||
reset_arm(arm)
|
||||
reset_torque_mode(arm)
|
||||
|
||||
# TODO(rcadene): document what position 1 mean
|
||||
print(
|
||||
f"Please move the '{name} {arm_type}' arm to the horizontal position (gripper fully closed, see {URL_HORIZONTAL_POSITION[arm_type]})"
|
||||
)
|
||||
print(f"\nRunning calibration of {name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
horizontal_homing_offset = compute_homing_offset(
|
||||
arm, [False, False, False, False, False, False], TARGET_HORIZONTAL_POSITION
|
||||
)
|
||||
# We arbitrarely choosed our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
|
||||
# corresponds to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
|
||||
zero_position = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# TODO(rcadene): document what position 2 mean
|
||||
print(
|
||||
f"Please move the '{name} {arm_type}' arm to the 90 degree position (gripper fully open, see {URL_90_DEGREE_POSITION[arm_type]})"
|
||||
)
|
||||
def _compute_nearest_rounded_position(position, models):
|
||||
# TODO(rcadene): Rework this function since some motors cant physically rotate a quarter turn
|
||||
# (e.g. the gripper of Aloha arms can only rotate ~50 degree)
|
||||
quarter_turn_degree = 90
|
||||
quarter_turn = convert_degrees_to_steps(quarter_turn_degree, models)
|
||||
nearest_pos = np.round(position.astype(float) / quarter_turn) * quarter_turn
|
||||
return nearest_pos.astype(position.dtype)
|
||||
|
||||
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
|
||||
position = arm.read("Present_Position")
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
homing_offset = zero_position - position
|
||||
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
drive_mode = compute_drive_mode(arm, horizontal_homing_offset)
|
||||
homing_offset = compute_homing_offset(arm, drive_mode, TARGET_90_DEGREE_POSITION)
|
||||
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
|
||||
# This allows to identify the rotation direction of each motor.
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
rotated_position = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Invert offset for all drive_mode servos
|
||||
for i in range(len(drive_mode)):
|
||||
if drive_mode[i]:
|
||||
homing_offset[i] = -homing_offset[i]
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
position = arm.read("Present_Position")
|
||||
position += homing_offset
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
drive_mode = (position != rotated_position).astype(np.int32)
|
||||
|
||||
print("Calibration is done!")
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
position = arm.read("Present_Position")
|
||||
position = apply_drive_mode(position, drive_mode)
|
||||
position = _compute_nearest_rounded_position(position, arm.motor_models)
|
||||
homing_offset = rotated_position - position
|
||||
|
||||
print("=====================================")
|
||||
print(" HOMING_OFFSET: ", " ".join([str(i) for i in homing_offset]))
|
||||
print(" DRIVE_MODE: ", " ".join([str(i) for i in drive_mode]))
|
||||
print("=====================================")
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot="koch", arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
return homing_offset, drive_mode
|
||||
|
||||
@@ -207,7 +168,12 @@ class KochRobotConfig:
|
||||
|
||||
class KochRobot:
|
||||
# TODO(rcadene): Implement force feedback
|
||||
"""Tau Robotics: https://tau-robotics.com
|
||||
"""This class allows to control any Koch robot of various number of motors.
|
||||
|
||||
A few versions are available:
|
||||
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, which was developed
|
||||
by Alexander Koch from [Tau Robotics](https://tau-robotics.com): [Github for sourcing and assembly](
|
||||
- [Koch v1.1])https://github.com/jess-moss/koch-v1-1), which was developed by Jess Moss.
|
||||
|
||||
Example of highest frequency teleoperation without camera:
|
||||
```python
|
||||
@@ -261,12 +227,12 @@ class KochRobot:
|
||||
Example of highest frequency data collection with cameras:
|
||||
```python
|
||||
# Defines how to communicate with 2 cameras connected to the computer.
|
||||
# Here, the webcam of the mackbookpro and the iphone (connected in USB to the macbookpro)
|
||||
# Here, the webcam of the laptop and the phone (connected in USB to the laptop)
|
||||
# can be reached respectively using the camera indices 0 and 1. These indices can be
|
||||
# arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices.
|
||||
cameras = {
|
||||
"macbookpro": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
||||
"iphone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
||||
"laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480),
|
||||
}
|
||||
|
||||
# Assumes leader and follower arms have been instantiated already (see first example)
|
||||
@@ -330,23 +296,27 @@ class KochRobot:
|
||||
|
||||
# Connect the arms
|
||||
for name in self.follower_arms:
|
||||
print(f"Connecting {name} follower arm.")
|
||||
self.follower_arms[name].connect()
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
# Reset the arms and load or run calibration
|
||||
if self.calibration_path.exists():
|
||||
# Reset all arms before setting calibration
|
||||
for name in self.follower_arms:
|
||||
reset_arm(self.follower_arms[name])
|
||||
reset_torque_mode(self.follower_arms[name])
|
||||
for name in self.leader_arms:
|
||||
reset_arm(self.leader_arms[name])
|
||||
reset_torque_mode(self.leader_arms[name])
|
||||
|
||||
with open(self.calibration_path, "rb") as f:
|
||||
calibration = pickle.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{self.calibration_path}'. Starting calibration precedure.")
|
||||
# Run calibration process which begins by reseting all arms
|
||||
calibration = self.run_calibration()
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{self.calibration_path}'")
|
||||
self.calibration_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.calibration_path, "wb") as f:
|
||||
pickle.dump(calibration, f)
|
||||
@@ -366,13 +336,14 @@ class KochRobot:
|
||||
|
||||
# Enable torque on all motors of the follower arms
|
||||
for name in self.follower_arms:
|
||||
print(f"Activating torque on {name} follower arm.")
|
||||
self.follower_arms[name].write("Torque_Enable", 1)
|
||||
|
||||
# Enable torque on the gripper of the leader arms, and move it to 45 degrees,
|
||||
# so that we can use it as a trigger to close the gripper of the follower arms.
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].write("Torque_Enable", 1, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN, "gripper")
|
||||
self.leader_arms[name].write("Goal_Position", GRIPPER_OPEN_DEGREE, "gripper")
|
||||
|
||||
# Connect the cameras
|
||||
for name in self.cameras:
|
||||
@@ -407,12 +378,12 @@ class KochRobot:
|
||||
"KochRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# Prepare to assign the positions of the leader to the follower
|
||||
# Prepare to assign the position of the leader to the follower
|
||||
leader_pos = {}
|
||||
for name in self.leader_arms:
|
||||
now = time.perf_counter()
|
||||
before_lread_t = time.perf_counter()
|
||||
leader_pos[name] = self.leader_arms[name].read("Present_Position")
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
|
||||
follower_goal_pos = {}
|
||||
for name in self.leader_arms:
|
||||
@@ -420,9 +391,9 @@ class KochRobot:
|
||||
|
||||
# Send action
|
||||
for name in self.follower_arms:
|
||||
now = time.perf_counter()
|
||||
before_fwrite_t = time.perf_counter()
|
||||
self.follower_arms[name].write("Goal_Position", follower_goal_pos[name])
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t
|
||||
|
||||
# Early exit when recording data is not requested
|
||||
if not record_data:
|
||||
@@ -432,9 +403,9 @@ class KochRobot:
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
now = time.perf_counter()
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -453,10 +424,10 @@ class KochRobot:
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
now = time.perf_counter()
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict, action_dict = {}, {}
|
||||
@@ -477,9 +448,9 @@ class KochRobot:
|
||||
# Read follower position
|
||||
follower_pos = {}
|
||||
for name in self.follower_arms:
|
||||
now = time.perf_counter()
|
||||
before_fread_t = time.perf_counter()
|
||||
follower_pos[name] = self.follower_arms[name].read("Present_Position")
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t
|
||||
|
||||
# Create state by concatenating follower current position
|
||||
state = []
|
||||
@@ -491,20 +462,16 @@ class KochRobot:
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
now = time.perf_counter()
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - now
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionnaries and format to pytorch
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
for name in self.cameras:
|
||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||
img = torch.from_numpy(images[name])
|
||||
img = img.type(torch.float32) / 255
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
obs_dict[f"observation.images.{name}"] = img
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
|
||||
@@ -158,6 +158,7 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
@@ -50,7 +50,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
@@ -48,7 +48,7 @@ training:
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
|
||||
39
lerobot/configs/robot/koch.yaml
Normal file
@@ -0,0 +1,39 @@
|
||||
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
|
||||
calibration_path: .cache/calibration/koch.pkl
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -1,9 +1,22 @@
|
||||
"""
|
||||
Utilities to control a robot.
|
||||
|
||||
Useful to record a dataset, replay a recorded episode, run the policy on your robot
|
||||
and record an evaluation dataset, and to recalibrate your robot if needed.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Recalibrate your robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate
|
||||
|
||||
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
|
||||
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
||||
@@ -14,7 +27,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
@@ -32,7 +45,7 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
|
||||
- Replay this test episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay_episode \
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
@@ -42,12 +55,11 @@ python lerobot/scripts/control_robot.py replay_episode \
|
||||
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--run-compute-stats 1 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
@@ -74,7 +86,14 @@ DATA_DIR=data python lerobot/scripts/train.py \
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py run_policy \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
```
|
||||
"""
|
||||
@@ -87,12 +106,14 @@ import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from huggingface_hub import create_branch
|
||||
from omegaconf import DictConfig
|
||||
from PIL import Image
|
||||
from termcolor import colored
|
||||
@@ -102,20 +123,45 @@ from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
def say(text, blocking=False):
|
||||
# Check if mac, linux, or windows.
|
||||
if platform.system() == "Darwin":
|
||||
cmd = f'say "{text}"'
|
||||
elif platform.system() == "Linux":
|
||||
cmd = f'spd-say "{text}"'
|
||||
elif platform.system() == "Windows":
|
||||
cmd = (
|
||||
'PowerShell -Command "Add-Type -AssemblyName System.Speech; '
|
||||
f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{text}')\""
|
||||
)
|
||||
|
||||
if not blocking and platform.system() in ["Darwin", "Linux"]:
|
||||
# TODO(rcadene): Make it work for Windows
|
||||
# Use the ampersand to run command in the background
|
||||
cmd += " &"
|
||||
|
||||
os.system(cmd)
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
@@ -160,11 +206,11 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
for name in robot.follower_arms:
|
||||
key = f"write_follower_{name}_goal_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
|
||||
key = f"read_follower_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
log_dt("dtWfoll", robot.logs[key])
|
||||
log_dt("dtRfoll", robot.logs[key])
|
||||
|
||||
for name in robot.cameras:
|
||||
key = f"read_camera_{name}_dt_s"
|
||||
@@ -179,12 +225,23 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
def get_is_headless():
|
||||
if platform.system() == "Linux":
|
||||
display = os.environ.get("DISPLAY")
|
||||
if display is None or display == "":
|
||||
return True
|
||||
return False
|
||||
@cache
|
||||
def is_headless():
|
||||
"""Detects if python is running without a monitor."""
|
||||
try:
|
||||
import pynput # noqa
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
print(
|
||||
"Error trying to import pynput. Switching to headless mode. "
|
||||
"As a result, the video stream from the cameras won't be shown, "
|
||||
"and you won't be able to change the control flow with keyboards. "
|
||||
"For more info, see traceback below.\n"
|
||||
)
|
||||
traceback.print_exc()
|
||||
print()
|
||||
return True
|
||||
|
||||
|
||||
########################################################################################
|
||||
@@ -192,29 +249,44 @@ def get_is_headless():
|
||||
########################################################################################
|
||||
|
||||
|
||||
def calibrate(robot: Robot):
|
||||
if robot.calibration_path.exists():
|
||||
print(f"Removing '{robot.calibration_path}'")
|
||||
robot.calibration_path.unlink()
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
# Calling `connect` automatically runs calibration
|
||||
# when the calibration file is missing
|
||||
robot.connect()
|
||||
|
||||
|
||||
def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
start_teleop_t = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
start_loop_t = time.perf_counter()
|
||||
robot.teleop_step()
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if teleop_time_s is not None and time.perf_counter() - start_time > teleop_time_s:
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
break
|
||||
|
||||
|
||||
def record_dataset(
|
||||
def record(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module | None = None,
|
||||
hydra_cfg: DictConfig | None = None,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
@@ -225,10 +297,18 @@ def record_dataset(
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writers=8,
|
||||
force_override=False,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
# TODO(rcadene): Clean this function via decomposition in higher level functions
|
||||
|
||||
_, dataset_name = repo_id.split("/")
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
)
|
||||
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
@@ -255,32 +335,10 @@ def record_dataset(
|
||||
else:
|
||||
episode_index = 0
|
||||
|
||||
is_headless = get_is_headless()
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
os.system('say "Warmup" &')
|
||||
is_warmup_print = True
|
||||
|
||||
now = time.perf_counter()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
if not is_headless:
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
if is_headless():
|
||||
logging.info(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
@@ -290,9 +348,7 @@ def record_dataset(
|
||||
stop_recording = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if is_headless:
|
||||
logging.info("Headless environment detected. Keyboard input will not be available.")
|
||||
else:
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
@@ -315,6 +371,53 @@ def record_dataset(
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
# Load policy if any
|
||||
if policy is not None:
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
# override fps using policy fps
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_warmup_t = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
say("Warming up")
|
||||
is_warmup_print = True
|
||||
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_warmup_t
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
# Using only 4 worker threads to avoid blocking the main thread.
|
||||
@@ -323,14 +426,18 @@ def record_dataset(
|
||||
# Start recording all episodes
|
||||
while episode_index < num_episodes:
|
||||
logging.info(f"Recording episode {episode_index}")
|
||||
os.system(f'say "Recording episode {episode_index}" &')
|
||||
say(f"Recording episode {episode_index}")
|
||||
ep_dict = {}
|
||||
frame_index = 0
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
now = time.perf_counter()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
@@ -342,11 +449,46 @@ def record_dataset(
|
||||
)
|
||||
]
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
for key in not_image_keys:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
if policy is not None:
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
# Order the robot to move
|
||||
robot.send_action(action)
|
||||
action = {"action": action}
|
||||
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
@@ -354,14 +496,13 @@ def record_dataset(
|
||||
|
||||
frame_index += 1
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
break
|
||||
@@ -369,10 +510,10 @@ def record_dataset(
|
||||
if not stop_recording:
|
||||
# Start resetting env while the executor are finishing
|
||||
logging.info("Reset the environment")
|
||||
os.system('say "Reset the environment" &')
|
||||
say("Reset the environment")
|
||||
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
start_vencod_t = time.perf_counter()
|
||||
|
||||
# During env reset we save the data and encode the videos
|
||||
num_frames = frame_index
|
||||
@@ -418,7 +559,7 @@ def record_dataset(
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
while timestamp < reset_time_s and not is_last_episode:
|
||||
time.sleep(1)
|
||||
timestamp = time.perf_counter() - start_time
|
||||
timestamp = time.perf_counter() - start_vencod_t
|
||||
pbar.update(1)
|
||||
if exit_early:
|
||||
exit_early = False
|
||||
@@ -433,8 +574,8 @@ def record_dataset(
|
||||
|
||||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
os.system('say "Done recording"')
|
||||
if not is_headless:
|
||||
say("Done recording", blocking=True)
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
@@ -444,10 +585,14 @@ def record_dataset(
|
||||
pass
|
||||
break
|
||||
|
||||
robot.disconnect()
|
||||
if not is_headless():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
num_episodes = episode_index
|
||||
|
||||
logging.info("Encoding videos")
|
||||
os.system('say "Encoding videos" &')
|
||||
say("Encoding videos")
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
@@ -455,6 +600,7 @@ def record_dataset(
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
@@ -491,11 +637,12 @@ def record_dataset(
|
||||
)
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
os.system('say "Computing dataset statistics" &')
|
||||
say("Computing dataset statistics")
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
lerobot_dataset.stats = stats
|
||||
else:
|
||||
logging.info("Skipping computation of the dataset statistrics")
|
||||
stats = {}
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
@@ -506,17 +653,17 @@ def record_dataset(
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
logging.info("Exiting")
|
||||
os.system('say "Exiting" &')
|
||||
|
||||
say("Exiting")
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
@@ -531,76 +678,20 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||
robot.connect()
|
||||
|
||||
logging.info("Replaying episode")
|
||||
os.system('say "Replaying episode"')
|
||||
|
||||
say("Replaying episode", blocking=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
now = time.perf_counter()
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = items[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record eval dataset and logs
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# add batch dimension to 1
|
||||
for name in observation:
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
|
||||
if device.type == "mps":
|
||||
for name in observation:
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
robot.send_action(action.to("cpu"))
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="mode", required=True)
|
||||
@@ -608,18 +699,26 @@ if __name__ == "__main__":
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument(
|
||||
"--robot",
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="koch",
|
||||
help="Name of the robot provided to the `make_robot(name)` factory function.",
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
base_parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
parser_teleop.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
@@ -638,19 +737,19 @@ if __name__ == "__main__":
|
||||
parser_record.add_argument(
|
||||
"--warmup-time-s",
|
||||
type=int,
|
||||
default=2,
|
||||
default=10,
|
||||
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--episode-time-s",
|
||||
type=int,
|
||||
default=10,
|
||||
default=60,
|
||||
help="Number of seconds for data recording for each episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--reset-time-s",
|
||||
type=int,
|
||||
default=5,
|
||||
default=60,
|
||||
help="Number of seconds for resetting the environment after each episode.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
@@ -666,6 +765,12 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writers",
|
||||
type=int,
|
||||
@@ -678,8 +783,23 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--policy-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
@@ -697,41 +817,46 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||
parser_policy.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser_policy.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
|
||||
control_mode = args.mode
|
||||
robot_name = args.robot
|
||||
robot_path = args.robot_path
|
||||
robot_overrides = args.robot_overrides
|
||||
kwargs = vars(args)
|
||||
del kwargs["mode"]
|
||||
del kwargs["robot"]
|
||||
del kwargs["robot_path"]
|
||||
del kwargs["robot_overrides"]
|
||||
|
||||
robot = make_robot(robot_name)
|
||||
if control_mode == "teleoperate":
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
if control_mode == "calibrate":
|
||||
calibrate(robot, **kwargs)
|
||||
|
||||
elif control_mode == "teleoperate":
|
||||
teleoperate(robot, **kwargs)
|
||||
elif control_mode == "record_dataset":
|
||||
record_dataset(robot, **kwargs)
|
||||
elif control_mode == "replay_episode":
|
||||
replay_episode(robot, **kwargs)
|
||||
|
||||
elif control_mode == "run_policy":
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
run_policy(robot, policy, hydra_cfg)
|
||||
elif control_mode == "record":
|
||||
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
|
||||
policy_overrides = args.policy_overrides
|
||||
del kwargs["pretrained_policy_name_or_path"]
|
||||
del kwargs["policy_overrides"]
|
||||
|
||||
policy_cfg = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
record(robot, policy, policy_cfg, **kwargs)
|
||||
else:
|
||||
record(robot, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(robot, **kwargs)
|
||||
|
||||
if robot.is_connected:
|
||||
# Disconnect manually to avoid a "Core dump" during process
|
||||
# termination due to camera threads not properly exiting.
|
||||
robot.disconnect()
|
||||
|
||||
@@ -56,6 +56,9 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
from huggingface_hub.utils._validators import HFValidationError
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
@@ -65,7 +68,7 @@ from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_pretrained_policy_path
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
@@ -498,6 +501,29 @@ def main(
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
|
||||
try:
|
||||
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
|
||||
except (HFValidationError, RepositoryNotFoundError) as e:
|
||||
if isinstance(e, HFValidationError):
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
|
||||
)
|
||||
else:
|
||||
error_message = (
|
||||
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
|
||||
)
|
||||
|
||||
logging.warning(f"{error_message} Treating it as a local directory.")
|
||||
pretrained_policy_path = Path(pretrained_policy_name_or_path)
|
||||
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
|
||||
raise ValueError(
|
||||
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
|
||||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
return pretrained_policy_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ from safetensors.torch import save_file
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict
|
||||
|
||||
|
||||
def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
@@ -114,6 +114,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
|
||||
)
|
||||
|
||||
|
||||
def push_dataset_card_to_hub(
|
||||
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, text=text)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
|
||||
"""Expect mp4 files to be all stored in a single "videos" directory.
|
||||
On the hugging face repositery, they will be uploaded in a "videos" directory at the root.
|
||||
@@ -213,10 +221,10 @@ def push_dataset_to_hub(
|
||||
if push_to_hub:
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main")
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
api = HfApi()
|
||||
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
if tests_data_dir:
|
||||
# get the first episode
|
||||
|
||||
@@ -120,7 +120,8 @@ def update_policy(
|
||||
policy.train()
|
||||
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
|
||||
output_dict = policy.forward(batch)
|
||||
loss = output_dict["loss"].mean()
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
loss = output_dict["loss"]
|
||||
grad_scaler.scale(loss).backward()
|
||||
|
||||
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
|
||||
@@ -149,12 +150,14 @@ def update_policy(
|
||||
policy.update()
|
||||
|
||||
info = {
|
||||
"loss": loss.item(),
|
||||
"grad_norm": float(grad_norm),
|
||||
"lr": optimizer.param_groups[0]["lr"],
|
||||
"update_s": time.perf_counter() - start_time,
|
||||
**{k: v.detach().mean().item() for k, v in output_dict.items() if "loss" in k},
|
||||
**{k: v for k, v in output_dict.items() if "loss" not in k},
|
||||
**{k: v for k, v in output_dict.items() if k != "loss"},
|
||||
}
|
||||
info.update({k: v for k, v in output_dict.items() if k not in info})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
|
||||
@@ -50,33 +50,19 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
```
|
||||
|
||||
- Run inference of a policy on the dataset and visualize the results:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id lerobot/pusht \
|
||||
--episodes 7 3 5 1 4
|
||||
-p lerobot/diffusion_pusht \
|
||||
--policy-overrides device=cpu
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.utils import get_pretrained_policy_path
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
@@ -99,7 +85,6 @@ def run_server(
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
has_policy: bool = False,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
@@ -139,7 +124,7 @@ def run_server(
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=has_policy,
|
||||
has_policy=False,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
@@ -150,7 +135,7 @@ def get_ep_csv_fname(episode_id: int):
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, inference_results=None):
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
@@ -158,7 +143,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
has_inference = inference_results is not None
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
@@ -168,13 +152,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
||||
if has_action:
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
if has_inference:
|
||||
if "action" in inference_results:
|
||||
dim_pred_action = inference_results["action"].shape[1]
|
||||
header += [f"pred_action_{i}" for i in range(dim_pred_action)]
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
header += [key]
|
||||
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
@@ -192,18 +169,6 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, infere
|
||||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
if has_inference:
|
||||
num_frames = len(rows)
|
||||
if "action" in inference_results:
|
||||
assert num_frames == inference_results["action"].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += inference_results["action"][i].tolist()
|
||||
for key in inference_results:
|
||||
if "loss" in key:
|
||||
assert num_frames == inference_results[key].shape[0]
|
||||
for i in range(num_frames):
|
||||
rows[i] += [inference_results[key][i].item()]
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
@@ -221,75 +186,6 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
]
|
||||
|
||||
|
||||
def run_inference(
|
||||
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="cuda"
|
||||
):
|
||||
if policy_method not in ["select_action", "forward"]:
|
||||
raise ValueError(
|
||||
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
|
||||
batch_size=1 if policy_method == "select_action" else batch_size,
|
||||
sampler=episode_sampler,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
warned_ndim_eq_0 = False
|
||||
warned_ndim_gt_2 = False
|
||||
|
||||
logging.info("Running inference")
|
||||
inference_results = {}
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
with torch.inference_mode():
|
||||
if policy_method == "select_action":
|
||||
gt_action = batch.pop("action")
|
||||
output_dict = {"action": policy.select_action(batch)}
|
||||
batch["action"] = gt_action
|
||||
elif policy_method == "forward":
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): Save and display all predicted actions at a given timestamp
|
||||
# Save predicted action for the next timestamp only
|
||||
output_dict["action"] = output_dict["action"][:, 0, :]
|
||||
|
||||
for key in output_dict:
|
||||
if output_dict[key].ndim == 0:
|
||||
if not warned_ndim_eq_0:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_eq_0 = True
|
||||
continue
|
||||
|
||||
if output_dict[key].ndim > 2:
|
||||
if not warned_ndim_gt_2:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_gt_2 = True
|
||||
continue
|
||||
|
||||
if key not in inference_results:
|
||||
inference_results[key] = []
|
||||
inference_results[key].append(output_dict[key].to("cpu"))
|
||||
|
||||
for key in inference_results:
|
||||
inference_results[key] = torch.cat(inference_results[key])
|
||||
|
||||
return inference_results
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
@@ -299,28 +195,10 @@ def visualize_dataset_html(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
policy_method: str = "select_action",
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: list[str] | None = None,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
has_policy = pretrained_policy_name_or_path is not None
|
||||
|
||||
if has_policy:
|
||||
logging.info("Loading policy")
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
|
||||
dataset = make_dataset(hydra_cfg)
|
||||
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
if policy_method == "select_action":
|
||||
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||
# an environment.
|
||||
dataset.delta_timestamps = None
|
||||
else:
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
@@ -328,11 +206,6 @@ def visualize_dataset_html(
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
|
||||
if has_policy:
|
||||
ckpt_str = pretrained_policy_path.parts[-2]
|
||||
exp_name = pretrained_policy_path.parts[-4]
|
||||
output_dir += f"_{exp_name}_{ckpt_str}_{policy_method}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
if force_override:
|
||||
@@ -357,31 +230,13 @@ def visualize_dataset_html(
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
inference_results = None
|
||||
if has_policy:
|
||||
inference_results_path = output_dir / f"episode_{episode_index}.safetensors"
|
||||
if inference_results_path.exists():
|
||||
inference_results = load_file(inference_results_path)
|
||||
else:
|
||||
inference_results = run_inference(
|
||||
dataset,
|
||||
episode_index,
|
||||
policy,
|
||||
policy_method,
|
||||
num_workers=hydra_cfg.training.num_workers,
|
||||
batch_size=hydra_cfg.training.batch_size,
|
||||
device=hydra_cfg.device,
|
||||
)
|
||||
inference_results_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(inference_results, inference_results_path)
|
||||
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, inference_results)
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy)
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -437,28 +292,6 @@ def main():
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--policy-method",
|
||||
type=str,
|
||||
default="select_action",
|
||||
choices=["select_action", "forward"],
|
||||
help="Python method used to run the inference. By default, set to `select_action` used during evaluation to output the sequence of actions. Can bet set to `forward` used during training to compute the loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--policy-overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override policy config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
|
||||
|
||||
|
Before Width: | Height: | Size: 416 KiB |
|
Before Width: | Height: | Size: 446 KiB |
BIN
media/koch/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 326 KiB |
BIN
media/koch/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 312 KiB |
BIN
media/koch/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 469 KiB |
|
Before Width: | Height: | Size: 318 KiB |
|
Before Width: | Height: | Size: 420 KiB |
BIN
media/koch/leader_rest.webp
Normal file
|
After Width: | Height: | Size: 339 KiB |
BIN
media/koch/leader_rotated.webp
Normal file
|
After Width: | Height: | Size: 232 KiB |
BIN
media/koch/leader_zero.webp
Normal file
|
After Width: | Height: | Size: 484 KiB |
BIN
media/tutorial/koch_v1_1_leader_follower.webp
Normal file
|
After Width: | Height: | Size: 58 KiB |
BIN
media/tutorial/visualize_dataset_html.webp
Normal file
|
After Width: | Height: | Size: 121 KiB |
60
poetry.lock
generated
@@ -1373,6 +1373,7 @@ files = [
|
||||
filelock = "*"
|
||||
fsspec = ">=2023.5.0"
|
||||
hf-transfer = {version = ">=0.1.4", optional = true, markers = "extra == \"hf-transfer\""}
|
||||
InquirerPy = {version = "0.3.4", optional = true, markers = "extra == \"cli\""}
|
||||
packaging = ">=20.9"
|
||||
pyyaml = ">=5.1"
|
||||
requests = "*"
|
||||
@@ -1559,6 +1560,24 @@ files = [
|
||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inquirerpy"
|
||||
version = "0.3.4"
|
||||
description = "Python port of Inquirer.js (A collection of common interactive command-line user interfaces)"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "InquirerPy-0.3.4-py3-none-any.whl", hash = "sha256:c65fdfbac1fa00e3ee4fb10679f4d3ed7a012abf4833910e63c295827fe2a7d4"},
|
||||
{file = "InquirerPy-0.3.4.tar.gz", hash = "sha256:89d2ada0111f337483cb41ae31073108b2ec1e618a49d7110b0d7ade89fc197e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pfzy = ">=0.3.1,<0.4.0"
|
||||
prompt-toolkit = ">=3.0.1,<4.0.0"
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "intel-openmp"
|
||||
version = "2021.4.0"
|
||||
@@ -2564,6 +2583,20 @@ other = ["pillow (>=8.0.1)"]
|
||||
sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"]
|
||||
testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"]
|
||||
|
||||
[[package]]
|
||||
name = "pfzy"
|
||||
version = "0.3.4"
|
||||
description = "Python port of the fzy fuzzy string matching algorithm"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "pfzy-0.3.4-py3-none-any.whl", hash = "sha256:5f50d5b2b3207fa72e7ec0ef08372ef652685470974a107d0d4999fc5a903a96"},
|
||||
{file = "pfzy-0.3.4.tar.gz", hash = "sha256:717ea765dd10b63618e7298b2d98efd819e0b30cd5905c9707223dceeb94b3f1"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=4.1.2,<5.0.0)", "furo (>=2021.8.17-beta.43,<2022.0.0)", "myst-parser (>=0.15.1,<0.16.0)", "sphinx-autobuild (>=2021.3.14,<2022.0.0)", "sphinx-copybutton (>=0.4.0,<0.5.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "10.4.0"
|
||||
@@ -2710,6 +2743,20 @@ nodeenv = ">=0.11.1"
|
||||
pyyaml = ">=5.1"
|
||||
virtualenv = ">=20.10.0"
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.47"
|
||||
description = "Library for building powerful interactive command lines in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
|
||||
{file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
wcwidth = "*"
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "5.27.2"
|
||||
@@ -4216,6 +4263,17 @@ perf = ["orjson"]
|
||||
sweeps = ["sweeps (>=0.2.0)"]
|
||||
workspaces = ["wandb-workspaces"]
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.13"
|
||||
description = "Measures the displayed width of unicode strings in a terminal"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
|
||||
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.0.3"
|
||||
@@ -4503,4 +4561,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "25d5a270d770d37b13a93bf72868d3b9e683f8af5252b6332ec926a26fd0c096"
|
||||
content-hash = "a340f2ed23db2f3c371c494cbc9a33392e122ed6713e6098277a87b3fb805f2b"
|
||||
|
||||
@@ -43,7 +43,7 @@ opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = ">=0.23.0"}
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.23.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from .utils import DEVICE
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEVICE, KOCH_ROBOT_CONFIG_PATH
|
||||
|
||||
|
||||
def pytest_collection_finish():
|
||||
@@ -27,11 +29,12 @@ def is_koch_available():
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
|
||||
robot = make_robot("koch")
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
|
||||
robot = make_robot(robot_cfg)
|
||||
robot.connect()
|
||||
del robot
|
||||
return True
|
||||
except Exception as e:
|
||||
print("An alexander koch robot is not available.")
|
||||
print("A koch robot is not available.")
|
||||
print(e)
|
||||
return False
|
||||
|
||||
@@ -13,18 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Save the policy tests artifacts.
|
||||
|
||||
Note: Run on the cluster
|
||||
|
||||
Example of usage:
|
||||
```bash
|
||||
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py
|
||||
```
|
||||
"""
|
||||
|
||||
import platform
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
@@ -66,7 +54,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
|
||||
loss = output_dict["loss"]
|
||||
|
||||
loss.mean().backward()
|
||||
loss.backward()
|
||||
grad_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
if param.requires_grad:
|
||||
@@ -108,21 +96,10 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||
print(f" - Validate with: `git add {env_policy_dir}`")
|
||||
print(f" - Revert with: `git checkout -- {env_policy_dir}`")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
|
||||
if (env_policy_dir / "output_dict.safetensors").exists():
|
||||
prev_loss = load_file(env_policy_dir / "output_dict.safetensors")["loss"]
|
||||
print(f"Previous loss={prev_loss}")
|
||||
print(f"New loss={output_dict['loss'].mean()}")
|
||||
print()
|
||||
|
||||
if env_policy_dir.exists():
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
@@ -130,32 +107,27 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if platform.machine() != "x86_64":
|
||||
raise OSError("Generate policy artifacts on x86_64 machine since it is used for the unit tests. ")
|
||||
|
||||
env_policies = [
|
||||
("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
[
|
||||
"policy.n_action_steps=8",
|
||||
"policy.num_inference_steps=10",
|
||||
"policy.down_dims=[128, 256, 512]",
|
||||
],
|
||||
"",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=false"], "use_policy"),
|
||||
# ("xarm", "tdmpc", ["policy.use_mpc=true"], "use_mpc"),
|
||||
# (
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# [
|
||||
# "policy.n_action_steps=8",
|
||||
# "policy.num_inference_steps=10",
|
||||
# "policy.down_dims=[128, 256, 512]",
|
||||
# ],
|
||||
# "",
|
||||
# ),
|
||||
# ("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
# ("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
]
|
||||
if len(env_policies) == 0:
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
print(f"env={env} policy={policy} extra_overrides={extra_overrides}")
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
)
|
||||
print()
|
||||
|
||||
@@ -3,13 +3,20 @@ from pathlib import Path
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.control_robot import record_dataset, replay_episode, run_policy, teleoperate
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_koch
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, KOCH_ROBOT_CONFIG_PATH, require_koch
|
||||
|
||||
|
||||
def make_robot_(overrides=None):
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH, overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
return robot
|
||||
|
||||
|
||||
@require_koch
|
||||
# `require_koch` uses `request` to access `is_koch_available` fixture
|
||||
def test_teleoperate(request):
|
||||
robot = make_robot("koch")
|
||||
robot = make_robot_()
|
||||
teleoperate(robot, teleop_time_s=1)
|
||||
teleoperate(robot, fps=30, teleop_time_s=1)
|
||||
teleoperate(robot, fps=60, teleop_time_s=1)
|
||||
@@ -17,20 +24,35 @@ def test_teleoperate(request):
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
|
||||
robot_name = "koch"
|
||||
def test_calibrate(request):
|
||||
robot = make_robot_()
|
||||
calibrate(robot)
|
||||
del robot
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_record_without_cameras(tmpdir, request):
|
||||
root = Path(tmpdir)
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
robot = make_robot_(overrides=["~cameras"])
|
||||
record(robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2)
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_record_and_replay_and_policy(tmpdir, request):
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
root = Path(tmpdir)
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
robot = make_robot(robot_name)
|
||||
dataset = record_dataset(
|
||||
robot = make_robot_()
|
||||
dataset = record(
|
||||
robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2
|
||||
)
|
||||
|
||||
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
@@ -43,6 +65,6 @@ def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
|
||||
run_policy(robot, policy, cfg, run_time_s=1)
|
||||
record(robot, policy, cfg, run_time_s=1)
|
||||
|
||||
del robot
|
||||
|
||||
@@ -23,6 +23,7 @@ import einops
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import lerobot
|
||||
@@ -34,6 +35,7 @@ from lerobot.common.datasets.compute_stats import (
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
load_previous_and_future_frames,
|
||||
@@ -385,3 +387,29 @@ def test_aggregate_stats():
|
||||
for agg_fn in ["mean", "min", "max"]:
|
||||
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
|
||||
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
|
||||
|
||||
|
||||
@pytest.mark.skip("Requires internet access")
|
||||
def test_create_branch():
|
||||
api = HfApi()
|
||||
|
||||
repo_id = "cadene/test_create_branch"
|
||||
repo_type = "dataset"
|
||||
branch = "test"
|
||||
ref = f"refs/heads/{branch}"
|
||||
|
||||
# Prepare a repo with a test branch
|
||||
api.delete_repo(repo_id, repo_type=repo_type, missing_ok=True)
|
||||
api.create_repo(repo_id, repo_type=repo_type)
|
||||
create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
# Make sure the test branch exists
|
||||
branches = api.list_repo_refs(repo_id, repo_type=repo_type).branches
|
||||
refs = [branch.ref for branch in branches]
|
||||
assert ref in refs
|
||||
|
||||
# Overwrite it
|
||||
create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
# Clean
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
@@ -1,33 +1,54 @@
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): test calibration
|
||||
# TODO(rcadene): add compatibility with other motors bus
|
||||
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from tests.utils import require_koch
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.utils import KOCH_ROBOT_CONFIG_PATH, require_koch
|
||||
|
||||
|
||||
def make_motors_bus():
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
|
||||
# Instantiating a common motors structure.
|
||||
# Here the one from Alexander Koch follower arm.
|
||||
motors_bus = hydra.utils.instantiate(robot_cfg.leader_arms.main)
|
||||
return motors_bus
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_find_port(request):
|
||||
from lerobot.common.robot_devices.motors.dynamixel import find_port
|
||||
|
||||
find_port()
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_configure_motors_all_ids_1(request):
|
||||
# This test expect the configuration was already correct.
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors))
|
||||
motors_bus.set_bus_baudrate(9_600)
|
||||
motors_bus.write("ID", [1] * len(motors_bus.motors))
|
||||
del motors_bus
|
||||
|
||||
# Test configure
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
assert motors_bus.are_motors_configured()
|
||||
del motors_bus
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_motors_bus(request):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): test calibration
|
||||
# TODO(rcadene): add compatibility with other motors bus
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
# Test instantiating a common motors structure.
|
||||
# Here the one from Alexander Koch follower arm.
|
||||
port = "/dev/tty.usbmodem575E0032081"
|
||||
motors = {
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
}
|
||||
motors_bus = DynamixelMotorsBus(port, motors)
|
||||
motors_bus = make_motors_bus()
|
||||
|
||||
# Test reading and writting before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
@@ -41,7 +62,7 @@ def test_motors_bus(request):
|
||||
del motors_bus
|
||||
|
||||
# Test connecting
|
||||
motors_bus = DynamixelMotorsBus(port, motors)
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
|
||||
# Test connecting twice raises an error
|
||||
@@ -52,7 +73,7 @@ def test_motors_bus(request):
|
||||
motors_bus.write("Torque_Enable", 0)
|
||||
values = motors_bus.read("Torque_Enable")
|
||||
assert isinstance(values, np.ndarray)
|
||||
assert len(values) == len(motors)
|
||||
assert len(values) == len(motors_bus.motors)
|
||||
assert (values == 0).all()
|
||||
|
||||
# Test writing torque on a specific motor
|
||||
@@ -83,10 +104,3 @@ def test_motors_bus(request):
|
||||
time.sleep(1)
|
||||
new_values = motors_bus.read("Present_Position")
|
||||
assert (new_values == values).all()
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_find_port(request):
|
||||
from lerobot.common.robot_devices.motors.dynamixel import find_port
|
||||
|
||||
find_port()
|
||||
|
||||
@@ -147,11 +147,10 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, n_envs=2)
|
||||
|
||||
batch_size = 2
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
batch_size=batch_size,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
pin_memory=DEVICE != "cpu",
|
||||
drop_last=True,
|
||||
@@ -165,19 +164,12 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
out = policy.forward(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
|
||||
# Test loss can be visualized using visualize_dataset_html.py
|
||||
for key in out:
|
||||
if "loss" in key:
|
||||
assert (
|
||||
out[key].ndim == 1 and out[key].shape[0] == batch_size
|
||||
), f"1 loss value per item in the batch is expected, but {out[key].shape} provided instead."
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
observation, _ = env.reset(seed=cfg.seed)
|
||||
@@ -242,7 +234,6 @@ def test_policy_defaults(policy_name: str):
|
||||
[
|
||||
("xarm", "tdmpc"),
|
||||
("pusht", "diffusion"),
|
||||
("pusht", "vqbet"),
|
||||
("aloha", "act"),
|
||||
],
|
||||
)
|
||||
@@ -259,7 +250,7 @@ def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy: Policy = policy_cls()
|
||||
save_dir = f"/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
@@ -374,7 +365,6 @@ def test_normalize(insert_temporal_dim):
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
"",
|
||||
),
|
||||
("pusht", "vqbet", "[]", ""),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
@@ -471,3 +461,7 @@ def test_act_temporal_ensembler():
|
||||
assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
|
||||
# Selected atol=1e-4 keeping in mind actions in [-1, 1] and excepting 0.01% error.
|
||||
assert torch.allclose(online_avg, offline_avg, atol=1e-4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_act_temporal_ensembler()
|
||||
|
||||
@@ -25,13 +25,13 @@ from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
||||
def test_visualize_dataset_root(tmpdir, repo_id, root):
|
||||
def test_visualize_local_dataset(tmpdir, repo_id, root):
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
root=root,
|
||||
)
|
||||
assert rrd_path.exists()
|
||||
|
||||
@@ -18,12 +18,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -39,34 +34,3 @@ def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, policy_method",
|
||||
[
|
||||
("lerobot/pusht", "select_action"),
|
||||
("lerobot/pusht", "forward"),
|
||||
],
|
||||
)
|
||||
def test_visualize_dataset_policy_ckpt_path(tmpdir, repo_id, policy_method):
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# Create a policy
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=["device=cpu"])
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
|
||||
# Save a checkpoint
|
||||
logger = Logger(cfg, tmpdir)
|
||||
logger.save_model(tmpdir, policy)
|
||||
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
serve=False,
|
||||
pretrained_policy_name_or_path=tmpdir,
|
||||
policy_method=policy_method,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
assert (tmpdir / "episode_0.safetensors").exists()
|
||||
|
||||
@@ -23,6 +23,7 @@ from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
KOCH_ROBOT_CONFIG_PATH = "lerobot/configs/robot/koch.yaml"
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -161,6 +162,7 @@ def require_koch(func):
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be passed to the test function as a parameter.")
|
||||
|
||||
# The function `is_koch_available` is defined in `tests/conftest.py`
|
||||
if not request.getfixturevalue("is_koch_available"):
|
||||
pytest.skip("An alexander koch robot is not available.")
|
||||
return func(*args, **kwargs)
|
||||
|
||||