Compare commits

..

9 Commits

Author SHA1 Message Date
Thomas Wolf
9b5d2fd37d fix aloha conversion changes 2024-05-31 11:31:28 +02:00
Remi Cadene
97ea288084 Add dora_aloha_real_act_real and dora_aloha_real_act_real_no_state test artifacts 2024-05-30 17:56:46 +00:00
Remi Cadene
671ad93b6c Rename dora_aloha_real, WIP test_policies 2024-05-30 17:54:59 +00:00
Remi Cadene
b7b5c3b4ff small fix 2024-05-30 13:38:19 +00:00
Remi Cadene
1397036a6b small fix 2024-05-30 13:36:34 +00:00
Remi Cadene
c1570e40c6 Add dora-lerobot to pyproject 2024-05-30 13:35:28 +00:00
Remi Cadene
8d847a58ef Rename Aloha2 to Aloha 2024-05-30 13:35:02 +00:00
Remi Cadene
48f974bb9e fix 2024-05-30 12:10:44 +00:00
Remi Cadene
511e39bdb8 Add aloha2_real, Add act_real, Fix vae=false, Add support for no state 2024-05-30 12:06:57 +00:00
55 changed files with 266 additions and 3033 deletions

View File

@@ -127,21 +127,13 @@ wandb login
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
You can also locally visualize episodes from a dataset by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```
or from a dataset in a local folder with the root `DATA_DIR` environment variable
```bash
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--episode-index 0
```
It will open `rerun.io` and display the camera streams, robot states and actions, like this:
https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
@@ -149,51 +141,6 @@ https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-f
Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.
### The `LeRobotDataset` format
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and Pytorch dataset. For instance `dataset[0]` will retrieve a sample of the dataset observations and actions in pytorch tensors format ready to be fed to a model.
A specificity of `LeRobotDataset` is that we can retrieve several frames for one sample query. By setting `delta_timestamps` to a list of delta timestamps, e.g. `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for each query, 4 images including one at -1 second before the current time step, the two others at -0.5 second and -0.2, and the final one at the current time step (0 second). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states.
Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects:
```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high: VideoFrame
│ │ VideoFrame = {'path': path to a mp4 video, 'timestamp': float32 timestamp in the video}
│ ├ observation.state: List of float32: position of an arm joints (for instance)
│ ... (more observations)
│ ├ action: List of float32
│ ├ episode_index: int64: index of the episode for this sample
│ ├ frame_index: int64: index of the frame for this sample in the episode ; starts at 0 for each episode
│ ├ timestamp: float32: timestamp in the episode
│ ├ next.done: bool: indicates the end of en episode ; True for the last frame in each episode
│ └ index: int64: general index in the whole dataset
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from: 1D int64 tensor of first frame index for each episode: shape (num episodes,) starts with 0
│ └ to: 1D int64 tensor of last frame index for each episode: shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ ...
├ info: a dictionary of metadata on the dataset
│ ├ fps: float - frame per second the dataset is recorded/synchronized to
│ └ video: bool - indicates if frames are encoded in mp4 video files to save space or stored as png files
├ videos_dir: path to where the mp4 videos or png images are stored/accessed
└ camera_keys: List of string: the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space or png files
- episode_data_index saved using `safetensor` tensor serializtion format
- stats saved using `safetensor` tensor serializtion format
- info are saved using JSON
Dataset can uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to you root dataset folder as illustrated in the above section on dataset visualization.
### Evaluate a pretrained policy
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.

View File

@@ -70,7 +70,7 @@ python lerobot/scripts/train.py policy=act env=aloha
There are two things to note here:
- Config overrides are passed as `param_name=param_value`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`.
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._

View File

@@ -1,89 +0,0 @@
# Using `lerobot` on a real world arm
In this example, we'll be using `lerobot` on a real world arm to:
- record a dataset in the `lerobot` format
- (soon) train a policy on it
- (soon) run the policy in the real-world
## Which robotic arm to use
In this example we're using the [open-source low-cost arm from Alexander Koch](https://github.com/AlexanderKoch-Koch/low_cost_robot) in the specific setup of:
- having 6 servos per arm, i.e. using the elbow-to-wrist extension
- adding two cameras around it, one on top and one in the front
- having a teleoperation arm as well (build the leader and the follower arms in A. Koch repo, both with elbow-to-wrist extensions)
I'm using these cameras (but the setup should not be sensitive to the exact cameras you're using):
- C922 Pro Stream Webcam
- Intel(R) RealSense D455 (using only the RGB input)
In general, this example should be very easily extendable to any type of arm using Dynamixel servos with at least one camera by changing a couple of configuration in the gym env.
## Install the example
Follow these steps:
- install `lerobot`
- install the Dynamixel-sdk: ` pip install dynamixel-sdk`
## Usage
### 0 - record examples
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
```
DATA_DIR='./data' python record_training_data.py \
--repo-id=thomwolf/blue_red_sort \
--num-episodes=50 \
--num-frames=400
```
TODO:
- various length episodes
- being able to drop episodes
- checking uploading to the hub
### 1 - visualize the dataset
Use the standard dataset visualization script pointing it to the right folder:
```
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
--repo-id thomwolf/blue_red_sort \
--episode-index 0
```
### 2 - Train a policy
From the example directory let's run this command to train a model using ACT
```
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
device=cuda \
hydra.searchpath=[file://./train_config/] \
hydra.run.dir=./outputs/train/blue_red_sort \
dataset_repo_id=thomwolf/blue_red_sort \
env=gym_real_world \
policy=act_real_world \
wandb.enable=false
```
### 3 - Evaluate the policy in the real world
From the example directory let's run this command to evaluate our policy.
The configuration for running the policy is in the checkpoint of the model.
You can override parameters as follow:
```
python run_policy.py \
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/
env.episode_length=1000
```
## Convert a hdf5 dataset recorded with the original ACT repo
You can convert a dataset from the raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act with the following command:
```
python ./lerobot/scripts/push_dataset_to_hub.py
```

View File

@@ -1,840 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from safetensors.torch import load_file, save_file\n",
"from pprint import pprint"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n",
"converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n",
"\n",
"comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n",
"comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n",
"comparison_config_json_path = comparison_main_path + \"config.json\"\n",
"comparison_config_yaml_path = comparison_main_path + \"config.yaml\""
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"b = load_file(comparison_safetensor_path)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.action_head.bias',\n",
" 'model.action_head.weight',\n",
" 'model.backbone.bn1.bias',\n",
" 'model.backbone.bn1.running_mean',\n",
" 'model.backbone.bn1.running_var',\n",
" 'model.backbone.bn1.weight',\n",
" 'model.backbone.conv1.weight',\n",
" 'model.backbone.layer1.0.bn1.bias',\n",
" 'model.backbone.layer1.0.bn1.running_mean',\n",
" 'model.backbone.layer1.0.bn1.running_var',\n",
" 'model.backbone.layer1.0.bn1.weight',\n",
" 'model.backbone.layer1.0.bn2.bias',\n",
" 'model.backbone.layer1.0.bn2.running_mean',\n",
" 'model.backbone.layer1.0.bn2.running_var',\n",
" 'model.backbone.layer1.0.bn2.weight',\n",
" 'model.backbone.layer1.0.conv1.weight',\n",
" 'model.backbone.layer1.0.conv2.weight',\n",
" 'model.backbone.layer1.1.bn1.bias',\n",
" 'model.backbone.layer1.1.bn1.running_mean',\n",
" 'model.backbone.layer1.1.bn1.running_var',\n",
" 'model.backbone.layer1.1.bn1.weight',\n",
" 'model.backbone.layer1.1.bn2.bias',\n",
" 'model.backbone.layer1.1.bn2.running_mean',\n",
" 'model.backbone.layer1.1.bn2.running_var',\n",
" 'model.backbone.layer1.1.bn2.weight',\n",
" 'model.backbone.layer1.1.conv1.weight',\n",
" 'model.backbone.layer1.1.conv2.weight',\n",
" 'model.backbone.layer2.0.bn1.bias',\n",
" 'model.backbone.layer2.0.bn1.running_mean',\n",
" 'model.backbone.layer2.0.bn1.running_var',\n",
" 'model.backbone.layer2.0.bn1.weight',\n",
" 'model.backbone.layer2.0.bn2.bias',\n",
" 'model.backbone.layer2.0.bn2.running_mean',\n",
" 'model.backbone.layer2.0.bn2.running_var',\n",
" 'model.backbone.layer2.0.bn2.weight',\n",
" 'model.backbone.layer2.0.conv1.weight',\n",
" 'model.backbone.layer2.0.conv2.weight',\n",
" 'model.backbone.layer2.0.downsample.0.weight',\n",
" 'model.backbone.layer2.0.downsample.1.bias',\n",
" 'model.backbone.layer2.0.downsample.1.running_mean',\n",
" 'model.backbone.layer2.0.downsample.1.running_var',\n",
" 'model.backbone.layer2.0.downsample.1.weight',\n",
" 'model.backbone.layer2.1.bn1.bias',\n",
" 'model.backbone.layer2.1.bn1.running_mean',\n",
" 'model.backbone.layer2.1.bn1.running_var',\n",
" 'model.backbone.layer2.1.bn1.weight',\n",
" 'model.backbone.layer2.1.bn2.bias',\n",
" 'model.backbone.layer2.1.bn2.running_mean',\n",
" 'model.backbone.layer2.1.bn2.running_var',\n",
" 'model.backbone.layer2.1.bn2.weight',\n",
" 'model.backbone.layer2.1.conv1.weight',\n",
" 'model.backbone.layer2.1.conv2.weight',\n",
" 'model.backbone.layer3.0.bn1.bias',\n",
" 'model.backbone.layer3.0.bn1.running_mean',\n",
" 'model.backbone.layer3.0.bn1.running_var',\n",
" 'model.backbone.layer3.0.bn1.weight',\n",
" 'model.backbone.layer3.0.bn2.bias',\n",
" 'model.backbone.layer3.0.bn2.running_mean',\n",
" 'model.backbone.layer3.0.bn2.running_var',\n",
" 'model.backbone.layer3.0.bn2.weight',\n",
" 'model.backbone.layer3.0.conv1.weight',\n",
" 'model.backbone.layer3.0.conv2.weight',\n",
" 'model.backbone.layer3.0.downsample.0.weight',\n",
" 'model.backbone.layer3.0.downsample.1.bias',\n",
" 'model.backbone.layer3.0.downsample.1.running_mean',\n",
" 'model.backbone.layer3.0.downsample.1.running_var',\n",
" 'model.backbone.layer3.0.downsample.1.weight',\n",
" 'model.backbone.layer3.1.bn1.bias',\n",
" 'model.backbone.layer3.1.bn1.running_mean',\n",
" 'model.backbone.layer3.1.bn1.running_var',\n",
" 'model.backbone.layer3.1.bn1.weight',\n",
" 'model.backbone.layer3.1.bn2.bias',\n",
" 'model.backbone.layer3.1.bn2.running_mean',\n",
" 'model.backbone.layer3.1.bn2.running_var',\n",
" 'model.backbone.layer3.1.bn2.weight',\n",
" 'model.backbone.layer3.1.conv1.weight',\n",
" 'model.backbone.layer3.1.conv2.weight',\n",
" 'model.backbone.layer4.0.bn1.bias',\n",
" 'model.backbone.layer4.0.bn1.running_mean',\n",
" 'model.backbone.layer4.0.bn1.running_var',\n",
" 'model.backbone.layer4.0.bn1.weight',\n",
" 'model.backbone.layer4.0.bn2.bias',\n",
" 'model.backbone.layer4.0.bn2.running_mean',\n",
" 'model.backbone.layer4.0.bn2.running_var',\n",
" 'model.backbone.layer4.0.bn2.weight',\n",
" 'model.backbone.layer4.0.conv1.weight',\n",
" 'model.backbone.layer4.0.conv2.weight',\n",
" 'model.backbone.layer4.0.downsample.0.weight',\n",
" 'model.backbone.layer4.0.downsample.1.bias',\n",
" 'model.backbone.layer4.0.downsample.1.running_mean',\n",
" 'model.backbone.layer4.0.downsample.1.running_var',\n",
" 'model.backbone.layer4.0.downsample.1.weight',\n",
" 'model.backbone.layer4.1.bn1.bias',\n",
" 'model.backbone.layer4.1.bn1.running_mean',\n",
" 'model.backbone.layer4.1.bn1.running_var',\n",
" 'model.backbone.layer4.1.bn1.weight',\n",
" 'model.backbone.layer4.1.bn2.bias',\n",
" 'model.backbone.layer4.1.bn2.running_mean',\n",
" 'model.backbone.layer4.1.bn2.running_var',\n",
" 'model.backbone.layer4.1.bn2.weight',\n",
" 'model.backbone.layer4.1.conv1.weight',\n",
" 'model.backbone.layer4.1.conv2.weight',\n",
" 'model.decoder.layers.0.linear1.bias',\n",
" 'model.decoder.layers.0.linear1.weight',\n",
" 'model.decoder.layers.0.linear2.bias',\n",
" 'model.decoder.layers.0.linear2.weight',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.decoder.layers.0.norm1.bias',\n",
" 'model.decoder.layers.0.norm1.weight',\n",
" 'model.decoder.layers.0.norm2.bias',\n",
" 'model.decoder.layers.0.norm2.weight',\n",
" 'model.decoder.layers.0.norm3.bias',\n",
" 'model.decoder.layers.0.norm3.weight',\n",
" 'model.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.decoder_pos_embed.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder_img_feat_input_proj.bias',\n",
" 'model.encoder_img_feat_input_proj.weight',\n",
" 'model.encoder_latent_input_proj.bias',\n",
" 'model.encoder_latent_input_proj.weight',\n",
" 'model.encoder_robot_and_latent_pos_embed.weight',\n",
" 'model.encoder_robot_state_input_proj.bias',\n",
" 'model.encoder_robot_state_input_proj.weight',\n",
" 'model.vae_encoder.layers.0.linear1.bias',\n",
" 'model.vae_encoder.layers.0.linear1.weight',\n",
" 'model.vae_encoder.layers.0.linear2.bias',\n",
" 'model.vae_encoder.layers.0.linear2.weight',\n",
" 'model.vae_encoder.layers.0.norm1.bias',\n",
" 'model.vae_encoder.layers.0.norm1.weight',\n",
" 'model.vae_encoder.layers.0.norm2.bias',\n",
" 'model.vae_encoder.layers.0.norm2.weight',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.1.linear1.bias',\n",
" 'model.vae_encoder.layers.1.linear1.weight',\n",
" 'model.vae_encoder.layers.1.linear2.bias',\n",
" 'model.vae_encoder.layers.1.linear2.weight',\n",
" 'model.vae_encoder.layers.1.norm1.bias',\n",
" 'model.vae_encoder.layers.1.norm1.weight',\n",
" 'model.vae_encoder.layers.1.norm2.bias',\n",
" 'model.vae_encoder.layers.1.norm2.weight',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.2.linear1.bias',\n",
" 'model.vae_encoder.layers.2.linear1.weight',\n",
" 'model.vae_encoder.layers.2.linear2.bias',\n",
" 'model.vae_encoder.layers.2.linear2.weight',\n",
" 'model.vae_encoder.layers.2.norm1.bias',\n",
" 'model.vae_encoder.layers.2.norm1.weight',\n",
" 'model.vae_encoder.layers.2.norm2.bias',\n",
" 'model.vae_encoder.layers.2.norm2.weight',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.3.linear1.bias',\n",
" 'model.vae_encoder.layers.3.linear1.weight',\n",
" 'model.vae_encoder.layers.3.linear2.bias',\n",
" 'model.vae_encoder.layers.3.linear2.weight',\n",
" 'model.vae_encoder.layers.3.norm1.bias',\n",
" 'model.vae_encoder.layers.3.norm1.weight',\n",
" 'model.vae_encoder.layers.3.norm2.bias',\n",
" 'model.vae_encoder.layers.3.norm2.weight',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.vae_encoder_action_input_proj.bias',\n",
" 'model.vae_encoder_action_input_proj.weight',\n",
" 'model.vae_encoder_cls_embed.weight',\n",
" 'model.vae_encoder_latent_output_proj.bias',\n",
" 'model.vae_encoder_latent_output_proj.weight',\n",
" 'model.vae_encoder_pos_enc',\n",
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
" 'normalize_inputs.buffer_observation_images_front.std',\n",
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
" 'normalize_inputs.buffer_observation_images_top.std',\n",
" 'normalize_inputs.buffer_observation_state.mean',\n",
" 'normalize_inputs.buffer_observation_state.std',\n",
" 'normalize_targets.buffer_action.mean',\n",
" 'normalize_targets.buffer_action.std',\n",
" 'unnormalize_outputs.buffer_action.mean',\n",
" 'unnormalize_outputs.buffer_action.std']\n"
]
}
],
"source": [
"dest = list(b.keys())\n",
"pprint(dest)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.pos_table',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
" 'model.transformer.encoder.layers.0.linear2.weight',\n",
" 'model.transformer.encoder.layers.0.linear2.bias',\n",
" 'model.transformer.encoder.layers.0.norm1.weight',\n",
" 'model.transformer.encoder.layers.0.norm1.bias',\n",
" 'model.transformer.encoder.layers.0.norm2.weight',\n",
" 'model.transformer.encoder.layers.0.norm2.bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.1.linear1.weight',\n",
" 'model.transformer.encoder.layers.1.linear1.bias',\n",
" 'model.transformer.encoder.layers.1.linear2.weight',\n",
" 'model.transformer.encoder.layers.1.linear2.bias',\n",
" 'model.transformer.encoder.layers.1.norm1.weight',\n",
" 'model.transformer.encoder.layers.1.norm1.bias',\n",
" 'model.transformer.encoder.layers.1.norm2.weight',\n",
" 'model.transformer.encoder.layers.1.norm2.bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.2.linear1.weight',\n",
" 'model.transformer.encoder.layers.2.linear1.bias',\n",
" 'model.transformer.encoder.layers.2.linear2.weight',\n",
" 'model.transformer.encoder.layers.2.linear2.bias',\n",
" 'model.transformer.encoder.layers.2.norm1.weight',\n",
" 'model.transformer.encoder.layers.2.norm1.bias',\n",
" 'model.transformer.encoder.layers.2.norm2.weight',\n",
" 'model.transformer.encoder.layers.2.norm2.bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.3.linear1.weight',\n",
" 'model.transformer.encoder.layers.3.linear1.bias',\n",
" 'model.transformer.encoder.layers.3.linear2.weight',\n",
" 'model.transformer.encoder.layers.3.linear2.bias',\n",
" 'model.transformer.encoder.layers.3.norm1.weight',\n",
" 'model.transformer.encoder.layers.3.norm1.bias',\n",
" 'model.transformer.encoder.layers.3.norm2.weight',\n",
" 'model.transformer.encoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.linear1.weight',\n",
" 'model.transformer.decoder.layers.0.linear1.bias',\n",
" 'model.transformer.decoder.layers.0.linear2.weight',\n",
" 'model.transformer.decoder.layers.0.linear2.bias',\n",
" 'model.transformer.decoder.layers.0.norm1.weight',\n",
" 'model.transformer.decoder.layers.0.norm1.bias',\n",
" 'model.transformer.decoder.layers.0.norm2.weight',\n",
" 'model.transformer.decoder.layers.0.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.norm3.weight',\n",
" 'model.transformer.decoder.layers.0.norm3.bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.linear1.weight',\n",
" 'model.transformer.decoder.layers.1.linear1.bias',\n",
" 'model.transformer.decoder.layers.1.linear2.weight',\n",
" 'model.transformer.decoder.layers.1.linear2.bias',\n",
" 'model.transformer.decoder.layers.1.norm1.weight',\n",
" 'model.transformer.decoder.layers.1.norm1.bias',\n",
" 'model.transformer.decoder.layers.1.norm2.weight',\n",
" 'model.transformer.decoder.layers.1.norm2.bias',\n",
" 'model.transformer.decoder.layers.1.norm3.weight',\n",
" 'model.transformer.decoder.layers.1.norm3.bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.linear1.weight',\n",
" 'model.transformer.decoder.layers.2.linear1.bias',\n",
" 'model.transformer.decoder.layers.2.linear2.weight',\n",
" 'model.transformer.decoder.layers.2.linear2.bias',\n",
" 'model.transformer.decoder.layers.2.norm1.weight',\n",
" 'model.transformer.decoder.layers.2.norm1.bias',\n",
" 'model.transformer.decoder.layers.2.norm2.weight',\n",
" 'model.transformer.decoder.layers.2.norm2.bias',\n",
" 'model.transformer.decoder.layers.2.norm3.weight',\n",
" 'model.transformer.decoder.layers.2.norm3.bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.linear1.weight',\n",
" 'model.transformer.decoder.layers.3.linear1.bias',\n",
" 'model.transformer.decoder.layers.3.linear2.weight',\n",
" 'model.transformer.decoder.layers.3.linear2.bias',\n",
" 'model.transformer.decoder.layers.3.norm1.weight',\n",
" 'model.transformer.decoder.layers.3.norm1.bias',\n",
" 'model.transformer.decoder.layers.3.norm2.weight',\n",
" 'model.transformer.decoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.3.norm3.weight',\n",
" 'model.transformer.decoder.layers.3.norm3.bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.linear1.weight',\n",
" 'model.transformer.decoder.layers.4.linear1.bias',\n",
" 'model.transformer.decoder.layers.4.linear2.weight',\n",
" 'model.transformer.decoder.layers.4.linear2.bias',\n",
" 'model.transformer.decoder.layers.4.norm1.weight',\n",
" 'model.transformer.decoder.layers.4.norm1.bias',\n",
" 'model.transformer.decoder.layers.4.norm2.weight',\n",
" 'model.transformer.decoder.layers.4.norm2.bias',\n",
" 'model.transformer.decoder.layers.4.norm3.weight',\n",
" 'model.transformer.decoder.layers.4.norm3.bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.linear1.weight',\n",
" 'model.transformer.decoder.layers.5.linear1.bias',\n",
" 'model.transformer.decoder.layers.5.linear2.weight',\n",
" 'model.transformer.decoder.layers.5.linear2.bias',\n",
" 'model.transformer.decoder.layers.5.norm1.weight',\n",
" 'model.transformer.decoder.layers.5.norm1.bias',\n",
" 'model.transformer.decoder.layers.5.norm2.weight',\n",
" 'model.transformer.decoder.layers.5.norm2.bias',\n",
" 'model.transformer.decoder.layers.5.norm3.weight',\n",
" 'model.transformer.decoder.layers.5.norm3.bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.linear1.weight',\n",
" 'model.transformer.decoder.layers.6.linear1.bias',\n",
" 'model.transformer.decoder.layers.6.linear2.weight',\n",
" 'model.transformer.decoder.layers.6.linear2.bias',\n",
" 'model.transformer.decoder.layers.6.norm1.weight',\n",
" 'model.transformer.decoder.layers.6.norm1.bias',\n",
" 'model.transformer.decoder.layers.6.norm2.weight',\n",
" 'model.transformer.decoder.layers.6.norm2.bias',\n",
" 'model.transformer.decoder.layers.6.norm3.weight',\n",
" 'model.transformer.decoder.layers.6.norm3.bias',\n",
" 'model.transformer.decoder.norm.weight',\n",
" 'model.transformer.decoder.norm.bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.action_head.weight',\n",
" 'model.action_head.bias',\n",
" 'model.is_pad_head.weight',\n",
" 'model.is_pad_head.bias',\n",
" 'model.query_embed.weight',\n",
" 'model.input_proj.weight',\n",
" 'model.input_proj.bias',\n",
" 'model.backbones.0.0.body.conv1.weight',\n",
" 'model.backbones.0.0.body.bn1.weight',\n",
" 'model.backbones.0.0.body.bn1.bias',\n",
" 'model.backbones.0.0.body.bn1.running_mean',\n",
" 'model.backbones.0.0.body.bn1.running_var',\n",
" 'model.backbones.0.0.body.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n",
" 'model.input_proj_robot_state.weight',\n",
" 'model.input_proj_robot_state.bias',\n",
" 'model.cls_embed.weight',\n",
" 'model.encoder_action_proj.weight',\n",
" 'model.encoder_action_proj.bias',\n",
" 'model.encoder_joint_proj.weight',\n",
" 'model.encoder_joint_proj.bias',\n",
" 'model.latent_proj.weight',\n",
" 'model.latent_proj.bias',\n",
" 'model.latent_out_proj.weight',\n",
" 'model.latent_out_proj.bias',\n",
" 'model.additional_pos_embed.weight']\n"
]
}
],
"source": [
"orig = list(a.keys())\n",
"pprint(orig)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)\n",
"\n",
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
" 'model.transformer.decoder.layers.2.',\n",
" 'model.transformer.decoder.layers.3.',\n",
" 'model.transformer.decoder.layers.4.',\n",
" 'model.transformer.decoder.layers.5.',\n",
" 'model.transformer.decoder.layers.6.',\n",
" 'model.transformer.decoder.norm.',\n",
" 'model.is_pad_head']\n",
"\n",
"to_remove_in = ['num_batches_tracked',]\n",
"\n",
"conv = {}\n",
"\n",
"keys = list(a.keys())\n",
"for k in keys:\n",
" if any(k.startswith(tr) for tr in to_remove_startswith):\n",
" a.pop(k)\n",
" continue\n",
" if any(tr in k for tr in to_remove_in):\n",
" a.pop(k)\n",
" continue\n",
" if k.startswith('model.transformer.encoder.layers.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.encoder.layers.'):\n",
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
" if k.startswith('model.action_head.'):\n",
" conv[k] = a.pop(k)\n",
" if k.startswith('model.pos_table'):\n",
" conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
" if k.startswith('model.query_embed.'):\n",
" conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
" if k.startswith('model.input_proj.'):\n",
" conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.input_proj_robot_state.'):\n",
" conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.backbones.0.0.body.'):\n",
" conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
" if k.startswith('model.cls_embed.'):\n",
" conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
" if k.startswith('model.encoder_action_proj.'):\n",
" conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
" if k.startswith('model.encoder_joint_proj.'):\n",
" conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_proj.'):\n",
" conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_out_proj.'):\n",
" conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
" if k.startswith('model.additional_pos_embed.'):\n",
" conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict()"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"for k, v in conv.items():\n",
" assert b[k].shape == v.shape\n",
" b[k] = v"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"save_file(b, converted_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now also copy the config files\n",
"import shutil\n",
"shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n",
"shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,8 +0,0 @@
from gymnasium.envs.registration import register
register(
id="gym_real_world/RealEnv-v0",
entry_point="gym_real_world.gym_environment:RealEnv",
max_episode_steps=300,
nondeterministic=True,
)

View File

@@ -1,363 +0,0 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Dynamixel class to control the dynamixel servos
"""
from __future__ import annotations
import enum
import math
import os
from dataclasses import dataclass
import numpy as np
from dynamixel_sdk import * # Uses Dynamixel SDK library
def pos2pwm(pos: np.ndarray) -> np.ndarray:
"""
:param pos: numpy array of joint positions in range [-pi, pi]
:return: numpy array of pwm values in range [0, 4096]
"""
return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)
def pwm2pos(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm values in range [0, 4096]
:return: numpy array of joint positions in range [-pi, pi]
"""
return (pwm / 2048 - 1) * 3.14
def pwm2vel(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm/s joint velocities
:return: numpy array of rad/s joint velocities
"""
return pwm * 3.14 / 2048
def vel2pwm(vel: np.ndarray) -> np.ndarray:
"""
:param vel: numpy array of rad/s joint velocities
:return: numpy array of pwm/s joint velocities
"""
return (vel * 2048 / 3.14).astype(np.int64)
class ReadAttribute(enum.Enum):
TEMPERATURE = 146
VOLTAGE = 145
VELOCITY = 128
POSITION = 132
CURRENT = 126
PWM = 124
HARDWARE_ERROR_STATUS = 70
HOMING_OFFSET = 20
BAUDRATE = 8
class OperatingMode(enum.Enum):
VELOCITY = 1
POSITION = 3
CURRENT_CONTROLLED_POSITION = 5
PWM = 16
UNKNOWN = -1
class Dynamixel:
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
ADDR_VELOCITY_LIMIT = 44
ADDR_GOAL_PWM = 100
OPERATING_MODE_ADDR = 11
POSITION_I = 82
POSITION_P = 84
ADDR_ID = 7
@dataclass
class Config:
def instantiate(self):
return Dynamixel(self)
baudrate: int = 57600
protocol_version: float = 2.0
device_name: str = "" # /dev/tty.usbserial-1120'
dynamixel_id: int = 1
def __init__(self, config: Config):
self.config = config
self.connect()
def connect(self):
if self.config.device_name == "":
for port_name in os.listdir("/dev"):
if "ttyUSB" in port_name or "ttyACM" in port_name:
self.config.device_name = "/dev/" + port_name
print(f"using device {self.config.device_name}")
self.portHandler = PortHandler(self.config.device_name)
# self.portHandler.LA
self.packetHandler = PacketHandler(self.config.protocol_version)
if not self.portHandler.openPort():
raise Exception(f"Failed to open port {self.config.device_name}")
if not self.portHandler.setBaudRate(self.config.baudrate):
raise Exception(f"failed to set baudrate to {self.config.baudrate}")
# self.operating_mode = OperatingMode.UNKNOWN
# self.torque_enabled = False
# self._disable_torque()
self.operating_modes = [None for _ in range(32)]
self.torque_enabled = [None for _ in range(32)]
return True
def disconnect(self):
self.portHandler.closePort()
def set_goal_position(self, motor_id, goal_position):
# if self.operating_modes[motor_id] is not OperatingMode.POSITION:
# self._disable_torque(motor_id)
# self.set_operating_mode(motor_id, OperatingMode.POSITION)
# if not self.torque_enabled[motor_id]:
# self._enable_torque(motor_id)
# self._enable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set position of motor {motor_id} to {goal_position}')
def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
if self.operating_modes[motor_id] is not OperatingMode.PWM:
self._disable_torque(motor_id)
self.set_operating_mode(motor_id, OperatingMode.PWM)
if not self.torque_enabled[motor_id]:
self._enable_torque(motor_id)
# print(f'enabling torque')
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set pwm of motor {motor_id} to {pwm_value}')
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
else:
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
def read_temperature(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
def read_velocity(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position_degrees(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 360
def read_position_radians(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 2 * math.pi
def read_current(self, motor_id: int):
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
if current > 2**15:
current -= 2**16
return current
def read_present_pwm(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.PWM, 2)
def read_hardware_error_status(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
def disconnect(self):
self.portHandler.closePort()
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
"""
sets the id of the dynamixel servo
@param old_id: current id of the servo
@param new_id: new id
@param use_broadcast_id: set ids of all connected dynamixels if True.
If False, change only servo with self.config.id
@return:
"""
if use_broadcast_id:
current_id = 254
else:
current_id = old_id
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, current_id, self.ADDR_ID, new_id
)
self._process_response(dxl_comm_result, dxl_error, old_id)
self.config.id = id
def _enable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = True
def _disable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = False
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
if dxl_comm_result != COMM_SUCCESS:
raise ConnectionError(
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
)
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.operating_modes[motor_id] = operating_mode
def set_pwm_limit(self, motor_id: int, limit: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_velocity_limit(self, motor_id: int, velocity_limit):
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_P(self, motor_id: int, P: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_P, P
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_I(self, motor_id: int, I: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_I, I
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def read_home_offset(self, motor_id: int):
self._disable_torque(motor_id)
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
# ReadAttribute.HOMING_OFFSET.value, home_position)
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
# self._process_response(dxl_comm_result, dxl_error)
self._enable_torque(motor_id)
return home_offset
def set_home_offset(self, motor_id: int, home_position: int):
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self._enable_torque(motor_id)
def set_baudrate(self, motor_id: int, baudrate):
# translate baudrate into dynamixel baudrate setting id
if baudrate == 57600:
baudrate_id = 1
elif baudrate == 1_000_000:
baudrate_id = 3
elif baudrate == 2_000_000:
baudrate_id = 4
elif baudrate == 3_000_000:
baudrate_id = 5
elif baudrate == 4_000_000:
baudrate_id = 6
else:
raise Exception("baudrate not implemented")
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
try:
if num_bytes == 1:
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 2:
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 4:
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
self.portHandler, motor_id, attribute.value
)
except Exception:
if tries == 0:
raise Exception
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
else:
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
time.sleep(0.02)
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
if tries == 0 and dxl_error != 128:
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
return value
def set_home_position(self, motor_id: int):
print(f"setting home position for motor {motor_id}")
self.set_home_offset(motor_id, 0)
current_position = self.read_position(motor_id)
print(f"position before {current_position}")
self.set_home_offset(motor_id, -current_position)
# dynamixel.set_home_offset(motor_id, -4096)
# dynamixel.set_home_offset(motor_id, -4294964109)
current_position = self.read_position(motor_id)
# print(f'signed position {current_position - 2** 32}')
print(f"position after {current_position}")
if __name__ == "__main__":
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
motor_id = 1
pos = dynamixel.read_position(motor_id)
for i in range(10):
s = time.monotonic()
pos = dynamixel.read_position(motor_id)
delta = time.monotonic() - s
print(f"read position took {delta}")
print(f"position {pos}")

View File

@@ -1,192 +0,0 @@
import time
from unittest.mock import MagicMock
import cv2
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from .dynamixel import pos2pwm, pwm2pos
from .robot import Robot
FPS = 30
CAMERAS_SHAPES = {
"images.high": (480, 640, 3),
"images.low": (480, 640, 3),
}
CAMERAS_PORTS = {
"images.high": "/dev/video6",
"images.low": "/dev/video0",
}
LEADER_PORT = "/dev/ttyACM1"
FOLLOWER_PORT = "/dev/ttyACM0"
MockRobot = MagicMock()
MockRobot.read_position = MagicMock()
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
MockCamera = MagicMock()
MockCamera.isOpened = MagicMock(return_value=True)
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
def capture_image(cam, cam_width, cam_height):
# Capture a single frame
_, frame = cam.read()
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# # Define your crop coordinates (top left corner and bottom right corner)
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
# # Crop the image
# image = image[y1:y2, x1:x2]
# Resize the image
image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)
return image
class RealEnv(gym.Env):
metadata = {}
def __init__(
self,
record: bool = False,
num_joints: int = 6,
cameras_shapes: dict = CAMERAS_SHAPES,
cameras_ports: dict = CAMERAS_PORTS,
follower_port: str = FOLLOWER_PORT,
leader_port: str = LEADER_PORT,
warmup_steps: int = 100,
trigger_torque=70,
fps: int = FPS,
fps_tolerance: float = 0.1,
mock: bool = False,
):
self.num_joints = num_joints
self.cameras_shapes = cameras_shapes
self.cameras_ports = cameras_ports
self.warmup_steps = warmup_steps
assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."
self.follower_port = follower_port
self.leader_port = leader_port
self.record = record
self.fps = fps
self.fps_tolerance = fps_tolerance
# Initialize the robot
self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
if self.record:
self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
self.leader.set_trigger_torque(trigger_torque)
# Initialize the cameras - sorted by camera names
self.cameras = {}
for cn, p in sorted(self.cameras_ports.items()):
self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
if not self.cameras[cn].isOpened():
raise OSError(
f"Cannot open camera port {p} for {cn}."
f" Make sure the camera is connected and the port is correct."
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
)
# Specify gym action and observation spaces
observation_space = {}
if self.num_joints > 0:
observation_space["agent_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.record:
observation_space["leader_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.cameras_shapes:
for cn, hwc_shape in self.cameras_shapes.items():
# Assumes images are unsigned int8 in [0,255]
observation_space[cn] = spaces.Box(
low=0,
high=255,
# height x width x channels (e.g. 480 x 640 x 3)
shape=hwc_shape,
dtype=np.uint8,
)
self.observation_space = spaces.Dict(observation_space)
self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)
self._observation = {}
self._terminated = False
self.timestamps = []
def _get_obs(self):
qpos = self.follower.read_position()
self._observation["agent_pos"] = pwm2pos(qpos)
for cn, c in self.cameras.items():
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
if self.record:
action = self.leader.read_position()
self._observation["leader_pos"] = pwm2pos(action)
def reset(self, seed: int | None = None):
# Reset the robot and sync the leader and follower if we are recording
for _ in range(self.warmup_steps):
self._get_obs()
if self.record:
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
self._terminated = False
info = {}
self.timestamps = []
return self._observation, info
def step(self, action: np.ndarray = None):
if self.timestamps:
# wait the right amount of time to stay at the desired fps
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
self.timestamps.append(time.time())
# Get the observation
self._get_obs()
if self.record:
# Teleoperate the leader
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
else:
# Apply the action to the follower
self.follower.set_goal_pos(pos2pwm(action))
reward = 0
terminated = truncated = self._terminated
info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
# Check if we are able to keep up with the desired fps
if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
self.fps - self.fps_tolerance
):
print(
f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
f" than min admited fps {(self.fps - self.fps_tolerance):.5f}"
f" at frame {len(self.timestamps)}"
)
info["fps_error"] = True
return self._observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
self.follower._disable_torque()
if self.record:
self.leader._disable_torque()

View File

@@ -1,168 +0,0 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Class to control the robot using dynamixel servos.
"""
from enum import Enum, auto
from typing import Union
import numpy as np
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite
from .dynamixel import Dynamixel, OperatingMode, ReadAttribute
class MotorControlType(Enum):
PWM = auto()
POSITION_CONTROL = auto()
DISABLED = auto()
UNKNOWN = auto()
class Robot:
def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
self.servo_ids = servo_ids
self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
self._init_motors()
def _init_motors(self):
self.position_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
)
for id in self.servo_ids:
self.position_reader.addParam(id)
self.velocity_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
)
for id in self.servo_ids:
self.velocity_reader.addParam(id)
self.pos_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
)
for id in self.servo_ids:
self.pos_writer.addParam(id, [2048])
self.pwm_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
)
for id in self.servo_ids:
self.pwm_writer.addParam(id, [2048])
self._disable_torque()
self.motor_control_state = MotorControlType.DISABLED
def read_position(self, tries=2):
"""
Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
:param tries: maximum number of tries to read the position
:return: list of joint positions in range [0, 4096]
"""
result = self.position_reader.txRxPacket()
if result != 0:
if tries > 0:
return self.read_position(tries=tries - 1)
else:
print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
positions = []
for id in self.servo_ids:
position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
if position > 2**31:
position -= 2**32
positions.append(position)
return np.array(positions)
def read_velocity(self):
"""
Reads the joint velocities of the robot.
:return: list of joint velocities,
"""
self.velocity_reader.txRxPacket()
velocties = []
for id in self.servo_ids:
velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
if velocity > 2**31:
velocity -= 2**32
velocties.append(velocity)
return np.array(velocties)
def set_goal_pos(self, action):
"""
:param action: list or numpy array of target joint positions in range [0, 4096]
"""
if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
self._set_position_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
DXL_LOBYTE(DXL_HIWORD(action[i])),
DXL_HIBYTE(DXL_HIWORD(action[i])),
]
self.pos_writer.changeParam(motor_id, data_write)
self.pos_writer.txPacket()
def set_pwm(self, action):
"""
Sets the pwm values for the servos.
:param action: list or numpy array of pwm values in range [0, 885]
"""
if self.motor_control_state is not MotorControlType.PWM:
self._set_pwm_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
]
self.pwm_writer.changeParam(motor_id, data_write)
self.pwm_writer.txPacket()
def set_trigger_torque(self, torque: int):
"""
Sets a constant torque torque for the last servo in the chain. This is useful for the trigger of the leader arm
"""
self.dynamixel._enable_torque(self.servo_ids[-1])
self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)
def limit_pwm(self, limit: Union[int, list, np.ndarray]):
"""
Limits the pwm values for the servos in for position control
@param limit: 0 ~ 885
@return:
"""
if isinstance(limit, int):
limits = [
limit,
] * 5
else:
limits = limit
self._disable_torque()
for motor_id, limit in zip(self.servo_ids, limits, strict=False):
self.dynamixel.set_pwm_limit(motor_id, limit)
self._enable_torque()
def _disable_torque(self):
print(f"disabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._disable_torque(motor_id)
def _enable_torque(self):
print(f"enabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._enable_torque(motor_id)
def _set_pwm_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
self._enable_torque()
self.motor_control_state = MotorControlType.PWM
def _set_position_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
self._enable_torque()
self.motor_control_state = MotorControlType.POSITION_CONTROL

View File

@@ -1,237 +0,0 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
"""
import argparse
import copy
import os
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value
from omegaconf import OmegaConf
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
# parse the repo_id name via command line
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
"--fps_tolerance",
type=float,
default=0.5,
help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
)
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
parser.add_argument("--mock_robot", action="store_true")
args = parser.parse_args()
repo_id = args.repo_id
num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
# During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images"
# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data"
gym_config = None
if args.config is not None:
gym_config = OmegaConf.load(args.config)
# Create image and video directories
if not os.path.exists(images_dir):
os.makedirs(images_dir, exist_ok=True)
if not os.path.exists(videos_dir):
os.makedirs(videos_dir, exist_ok=True)
if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0"
gym_kwargs = {}
if gym_config is not None:
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
env = gym.make(
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
ep_fps = []
id_from = 0
id_to = 0
os.system('spd-say "gym environment created"')
ep_idx = 0
while ep_idx < num_episodes:
# bring the follower to the leader and start camera
env.reset()
os.system(f'spd-say "go {ep_idx}"')
# init buffers
obs_replay = {k: [] for k in env.observation_space}
drop_episode = False
timestamps = []
for _ in tqdm(range(num_frames)):
# Apply the next action
observation, _, _, _, info = env.step(action=None)
# images_stacked = np.hstack(list(observation['pixels'].values()))
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
# cv2.imshow('frame', images_stacked)
if info["fps_error"]:
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
drop_episode = True
break
# store data
for key in observation:
obs_replay[key].append(copy.deepcopy(observation[key]))
timestamps.append(info["timestamp"])
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
os.system('spd-say "stop"')
if not drop_episode:
os.system(f'spd-say "saving episode {ep_idx}"')
ep_dict = {}
# store images in png and create the video
for img_key in env.cameras:
save_images_concurrently(
obs_replay[img_key],
images_dir / f"{img_key}_episode_{ep_idx:06d}",
args.num_workers,
)
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
# store the reference to the video frame
ep_dict[f"observation.{img_key}"] = [
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
]
state = torch.tensor(np.array(obs_replay["agent_pos"]))
action = torch.tensor(np.array(obs_replay["leader_pos"]))
next_done = torch.zeros(num_frames, dtype=torch.bool)
next_done[-1] = True
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.tensor(timestamps)
ep_dict["next.done"] = next_done
ep_fps.append(num_frames / timestamps[-1])
ep_dicts.append(ep_dict)
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(
id_from + num_frames if args.keep_last else id_from + num_frames - 1
)
id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
id_from = id_to
ep_idx += 1
env.close()
os.system('spd-say "encode video frames"')
for ep_idx in range(num_episodes):
for img_key in env.cameras:
# If necessary, we may want to encode the video
# with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
encode_video_frames(
images_dir / f"{img_key}_episode_{ep_idx:06d}",
videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
ep_fps[ep_idx],
)
os.system('spd-say "concatenate episodes"')
data_dict = concatenate_episodes(
ep_dicts, drop_episodes_last_frame=not args.keep_last
) # Since our fps varies we are sometimes off tolerance for the last frame
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
features[key] = VideoFrame()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video
"video": 1,
}
os.system('spd-say "from preloaded"')
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
version=revision,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
os.system('spd-say "compute stats"')
stats = compute_stats(lerobot_dataset)
os.system('spd-say "save to disk"')
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(out_data / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if args.push_to_hub:
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
push_videos_to_hub(repo_id, videos_dir, revision="main")
push_videos_to_hub(repo_id, videos_dir, revision=revision)

View File

@@ -1,60 +0,0 @@
import argparse
import logging
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)

View File

@@ -1,19 +0,0 @@
# @package _global_
fps: 30
env:
name: real_world
task: RealEnv-v0
state_dim: 6
action_dim: 6
fps: ${fps}
episode_length: 200
real_world: true
gym:
cameras_shapes:
images.high: [480, 640, 3]
images.low: [480, 640, 3]
cameras_ports:
images.high: /dev/video6
images.low: /dev/video0

View File

@@ -1,19 +0,0 @@
# @package _global_
fps: 30
env:
name: real_world
task: RealEnv-v0
state_dim: 6
action_dim: 6
fps: ${fps}
episode_length: 200
real_world: true
gym:
cameras_shapes:
images.top: [480, 640, 3]
images.front: [480, 640, 3]
cameras_ports:
images.top: /dev/video6
images.front: /dev/video0

View File

@@ -1,103 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.high: [3, 480, 640]
observation.images.low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.high: mean_std
observation.images.low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -1,103 +0,0 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# front) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.front:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.images.front: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.top: mean_std
observation.images.front: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -45,9 +45,6 @@ import itertools
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -55,7 +52,7 @@ available_tasks_per_env = {
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -81,7 +78,7 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora_aloha_real": [
"dora": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
@@ -129,19 +126,17 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
available_policies = [
"act",
"diffusion",
"tdmpc",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"dora": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]

View File

@@ -16,9 +16,9 @@
import logging
import torch
from omegaconf import ListConfig, OmegaConf
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def resolve_delta_timestamps(cfg):
@@ -35,54 +35,25 @@ def resolve_delta_timestamps(cfg):
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the hugging face datasets:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if not cfg.env.real_world:
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
resolve_delta_timestamps(cfg)
# TODO(rcadene): add data augmentations
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
)
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():

View File

@@ -13,16 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
@@ -46,7 +42,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
@@ -175,7 +171,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def from_preloaded(
cls,
repo_id: str = "from_preloaded",
repo_id: str,
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
@@ -187,15 +183,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
on the filesystem or uploading to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
meaningless depending on the downstream usage of the return dataset.
"""
):
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
@@ -207,193 +195,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info if info is not None else {}
obj.info = info
obj.videos_dir = videos_dir
return obj
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.repo_ids = repo_ids
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples
dataset_idx += 1
continue
break
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)

View File

@@ -78,15 +78,29 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
image_keys = [key for key in df if "observation.images." in key]
num_unaligned_images = 0
max_episode = 0
def get_episode_index(row):
nonlocal num_unaligned_images
nonlocal max_episode
episode_index_per_cam = {}
for key in image_keys:
if isinstance(row[key], float):
num_unaligned_images += 1
return float("nan")
path = row[key][0]["path"]
match = re.search(r"_(\d{6}).mp4", path)
if not match:
raise ValueError(path)
episode_index = int(match.group(1))
episode_index_per_cam[key] = episode_index
if episode_index > max_episode:
assert episode_index - max_episode == 1
max_episode = episode_index
else:
assert episode_index == max_episode
if len(set(episode_index_per_cam.values())) != 1:
raise ValueError(
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
@@ -111,11 +125,24 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
del df["timestamp_utc"]
# sanity check
has_nan = df.isna().any().any()
if has_nan:
raise ValueError("Dataset contains Nan values.")
num_rows_with_nan = df.isna().any(axis=1).sum()
assert (
num_rows_with_nan == num_unaligned_images
), f"Found {num_rows_with_nan} rows with NaN values but {num_unaligned_images} unaligned images."
if num_unaligned_images > max_episode * 2:
# We allow a few unaligned images, typically at the beginning and end of the episodes for instance
# but if there are too many, we raise an error to avoid large chunks of missing data
raise ValueError(
f"Found {num_unaligned_images} unaligned images out of {max_episode} episodes. "
f"Check the timestamps of the cameras."
)
# Drop rows with NaN values now that we double checked and convert episode_index to int
df = df.dropna()
df["episode_index"] = df["episode_index"].astype(int)
# sanity check episode indices go from 0 to n-1
assert df["episode_index"].max() == max_episode
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
@@ -214,8 +241,6 @@ def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=Tru
if fps is None:
fps = 30
else:
raise NotImplementedError()
if not video:
raise NotImplementedError()

View File

@@ -43,6 +43,9 @@ def get_cameras(hdf5_data):
def check_format(raw_dir) -> bool:
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
assert len(hdf5_paths) != 0
for hdf5_path in hdf5_paths:
@@ -59,15 +62,17 @@ def check_format(raw_dir) -> bool:
for camera in get_cameras(data):
assert num_frames == data[f"/observations/images/{camera}"].shape[0]
# ndim 2 when image are compressed and 4 when uncompressed
assert data[f"/observations/images/{camera}"].ndim in [2, 4]
if data[f"/observations/images/{camera}"].ndim == 4:
if compressed_images:
assert data[f"/observations/images/{camera}"].ndim == 2
else:
assert data[f"/observations/images/{camera}"].ndim == 4
b, h, w, c = data[f"/observations/images/{camera}"].shape
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
def load_from_raw(raw_dir, out_dir, fps, video, debug):
# only frames from simulation are uncompressed
compressed_images = "sim" not in raw_dir.name
hdf5_files = list(raw_dir.glob("*.hdf5"))
ep_dicts = []
@@ -94,7 +99,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
for camera in get_cameras(ep):
img_key = f"observation.images.{camera}"
if ep[f"/observations/images/{camera}"].ndim == 2:
if compressed_images:
import cv2
# load one compressed image after the other in RAM and uncompress

View File

@@ -16,15 +16,17 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0):
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
@@ -64,8 +66,9 @@ def get_stats_einops_patterns(dataset, num_workers=0):
return stats_patterns
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
if max_num_samples is None:
max_num_samples = len(dataset)
@@ -156,54 +159,3 @@ def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"min": min[key],
}
return stats
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
The final stats will have the union of all data keys from each of the datasets.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
)
return stats

View File

@@ -21,24 +21,19 @@ import PIL
import torch
def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False):
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
if drop_episodes_last_frame:
data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts])
else:
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
if drop_episodes_last_frame:
data_dict[key].pop()
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)

View File

@@ -1,61 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterator, Union
import torch
class EpisodeAwareSampler:
def __init__(
self,
episode_data_index: dict,
episode_indices_to_use: Union[list, None] = None,
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
):
"""Sampler that optionally incorporates episode boundary information.
Args:
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
Assumes that episodes are indexed from 0 to N-1.
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
"""
indices = []
for episode_idx, (start_index, end_index) in enumerate(
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
):
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
indices.extend(
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
)
self.indices = indices
self.shuffle = shuffle
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
yield i
def __len__(self) -> int:
return len(self.indices)

View File

@@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
def hf_transform_to_torch(items_dict):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -73,8 +73,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
@@ -320,7 +318,8 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
"""
Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
@@ -339,7 +338,6 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset

View File

@@ -29,12 +29,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
# map to expected inputs for the policy
return_observations = {}
if "pixels" in observations and isinstance(observations["pixels"], dict):
if isinstance(observations["pixels"], dict):
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
imgs = {"observation.image": observations["pixels"]}
else:
imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}
imgs = {"observation.image": observations["pixels"]}
for imgkey, img in imgs.items():
img = torch.from_numpy(img)

View File

@@ -26,10 +26,11 @@ class ACTConfig:
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:
@@ -129,9 +130,7 @@ class ACTConfig:
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
# As a consequence we also remove the final, unused layer normalization, by default
n_decoder_layers: int = 1
decoder_norm: bool = False
# VAE.
use_vae: bool = True
latent_dim: int = 32

View File

@@ -198,14 +198,15 @@ class ACT(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes
self.has_state = "observation.state" in config.input_shapes
self.latent_dim = config.latent_dim
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_input_state:
if self.has_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
@@ -214,12 +215,10 @@ class ACT(nn.Module):
config.output_shapes["action"][0], config.dim_model
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
num_input_token_encoder += 1
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
@@ -242,16 +241,16 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
if self.use_input_state:
if self.has_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
num_input_token_decoder = 2 if self.use_input_state else 1
num_input_token_decoder = 2 if self.has_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
@@ -299,12 +298,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_input_state:
if self.has_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_input_state:
if self.has_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@@ -315,19 +314,13 @@ class ACT(nn.Module):
pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D)
# Forward pass through VAE encoder to get the latent PDF parameters.
cls_joint_is_pad = torch.full((batch_size, 2), False).to(
batch["observation.state"].device
) # False: not a padding
key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1) # (bs, seq+1)
cls_token_out = self.vae_encoder(
vae_encoder_input.permute(1, 0, 2),
pos_embed=pos_embed.permute(1, 0, 2),
key_padding_mask=key_padding_mask,
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.config.latent_dim]
mu = latent_pdf_params[:, : self.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
@@ -335,7 +328,7 @@ class ACT(nn.Module):
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
@@ -357,12 +350,12 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
if self.use_input_state:
if self.has_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack(encoder_in_feats, axis=0),
@@ -408,11 +401,9 @@ class ACTEncoder(nn.Module):
self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
def forward(
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
) -> Tensor:
def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
for layer in self.layers:
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
x = layer(x, pos_embed=pos_embed)
x = self.norm(x)
return x
@@ -435,14 +426,12 @@ class ACTEncoderLayer(nn.Module):
self.activation = get_activation_fn(config.feedforward_activation)
self.pre_norm = config.pre_norm
def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
skip = x
if self.pre_norm:
x = self.norm1(x)
q = k = x if pos_embed is None else x + pos_embed
x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
0
] # select just the output, not the attention weights
x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights
x = skip + self.dropout1(x)
if self.pre_norm:
skip = x
@@ -462,10 +451,7 @@ class ACTDecoder(nn.Module):
"""Convenience module for running multiple decoder layers followed by normalization."""
super().__init__()
self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
if config.decoder_norm:
self.norm = nn.LayerNorm(config.dim_model)
else:
self.norm = nn.Identity()
self.norm = nn.LayerNorm(config.dim_model)
def forward(
self,
@@ -478,7 +464,8 @@ class ACTDecoder(nn.Module):
x = layer(
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
)
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
return x

View File

@@ -28,7 +28,10 @@ class DiffusionConfig:
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image is required as an input.
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
views.
Right now we only support all images having the same shape.
- "action" is required as an output key.
Args:

View File

@@ -239,8 +239,10 @@ class DiffusionModel(nn.Module):
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
# run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond)
sample = self.conditional_sample(batch_size, global_cond=global_cond)
# `horizon` steps worth of actions (from the first observation).
actions = sample[..., : self.config.output_shapes["action"][0]]
# Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1
end = start + self.config.n_action_steps

View File

@@ -147,7 +147,7 @@ class Normalize(nn.Module):
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
batch[key] = (batch[key] - min) / (max - min)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:

View File

@@ -120,13 +120,13 @@ def init_logging():
logging.getLogger().addHandler(console_handler)
def format_big_number(num, precision=0):
def format_big_number(num):
suffixes = ["", "K", "M", "B", "T", "Q"]
divisor = 1000.0
for suffix in suffixes:
if abs(num) < divisor:
return f"{num:.{precision}f}{suffix}"
return f"{num:.0f}{suffix}"
num /= divisor
return num

View File

@@ -23,10 +23,6 @@ use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datsets are provided.
dataset_repo_id: lerobot/pusht
training:
@@ -50,8 +46,6 @@ eval:
batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false
# Specify the number of episodes to render during evaluation.
max_episodes_rendered: 10
wandb:
enable: false

View File

@@ -9,7 +9,6 @@ env:
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@@ -9,6 +9,5 @@ env:
action_dim: 14
fps: ${fps}
episode_length: 400
real_world: true
gym:
fps: ${fps}

View File

@@ -10,7 +10,6 @@ env:
action_dim: 2
fps: ${fps}
episode_length: 300
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@@ -10,7 +10,6 @@ env:
action_dim: 4
fps: ${fps}
episode_length: 25
real_world: false
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@@ -11,7 +11,7 @@
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=dora_aloha_real
# env=aloha_real
# ```
seed: 1000

View File

@@ -9,7 +9,7 @@
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=dora_aloha_real
# env=aloha_real
# ```
seed: 1000

View File

@@ -44,10 +44,6 @@ training:
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50

View File

@@ -44,7 +44,6 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
import argparse
import json
import logging
import os
import threading
import time
from contextlib import nullcontext
@@ -165,10 +164,7 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available of none of the envs finished.
if "final_info" in info:
successes = [
i["is_success"] if (i is not None and "is_success" in i) else False
for i in info["final_info"]
]
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
else:
successes = [False] * env.num_envs
@@ -520,7 +516,6 @@ def eval(
out_dir = (
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
)
os.makedirs(out_dir, exist_ok=True)
if out_dir is None:
raise NotImplementedError()
@@ -550,7 +545,7 @@ def eval(
env,
policy,
hydra_cfg.eval.n_episodes,
max_episodes_rendered=hydra_cfg.eval.max_episodes_rendered,
max_episodes_rendered=10,
video_dir=Path(out_dir) / "eval",
start_seed=hydra_cfg.seed,
enable_progbar=True,

View File

@@ -71,9 +71,9 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict

View File

@@ -16,6 +16,7 @@
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
@@ -27,8 +28,6 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -281,11 +280,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -336,7 +330,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
@@ -357,28 +351,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("Resume training")
# create dataloader for offline training
if cfg.training.get("drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.training.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=cfg.training.num_workers,
batch_size=cfg.training.batch_size,
shuffle=shuffle,
sampler=sampler,
shuffle=True,
pin_memory=device.type != "cpu",
drop_last=False,
)
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
@@ -398,7 +382,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@@ -406,9 +390,41 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
logging.info("End of offline training")
if cfg.training.online_steps == 0:
if cfg.training.eval_freq > 0:
eval_env.close()
return
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
logging.info("End of online training")
if cfg.training.eval_freq > 0:
eval_env.close()
logging.info("End of training")
online_training_env.close()
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")

View File

@@ -224,8 +224,7 @@ def main():
help=(
"Mode of viewing between 'local' or 'distant'. "
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
"'distant' creates a server on the distant machine where the data is stored. "
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
),
)
parser.add_argument(
@@ -246,8 +245,8 @@ def main():
default=0,
help=(
"Save a .rrd file in the directory provided by `--output-dir`. "
"It also deactivates the spawning of a viewer. "
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
"It also deactivates the spawning of a viewer. ",
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
),
)
parser.add_argument(

163
poetry.lock generated
View File

@@ -444,63 +444,63 @@ files = [
[[package]]
name = "coverage"
version = "7.5.3"
version = "7.5.1"
description = "Code coverage measurement for Python"
optional = true
python-versions = ">=3.8"
files = [
{file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
{file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
{file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
{file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
{file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
{file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
{file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
{file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
{file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
{file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
{file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
{file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
{file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
{file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
{file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
{file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
{file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
{file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
{file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
{file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
{file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
]
[package.dependencies]
@@ -1104,7 +1104,7 @@ pyarrow = ">=12.0.0"
type = "git"
url = "https://github.com/dora-rs/dora-lerobot.git"
reference = "HEAD"
resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
resolved_reference = "1c6c2a401c3a2967d41444be6286ca9a28893abf"
subdirectory = "gym_dora"
[[package]]
@@ -1310,13 +1310,13 @@ files = [
[[package]]
name = "huggingface-hub"
version = "0.23.2"
version = "0.23.1"
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
{file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
{file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
{file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
]
[package.dependencies]
@@ -2102,15 +2102,18 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
[[package]]
name = "nodeenv"
version = "1.9.0"
version = "1.8.0"
description = "Node.js virtual environment builder"
optional = true
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
files = [
{file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
{file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
]
[package.dependencies]
setuptools = "*"
[[package]]
name = "numba"
version = "0.59.1"
@@ -3228,13 +3231,13 @@ files = [
[[package]]
name = "requests"
version = "2.32.3"
version = "2.32.2"
description = "Python HTTP for Humans."
optional = false
python-versions = ">=3.8"
files = [
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
]
[package.dependencies]
@@ -3250,16 +3253,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rerun-sdk"
version = "0.16.1"
version = "0.16.0"
description = "The Rerun Logging SDK"
optional = false
python-versions = "<3.13,>=3.8"
files = [
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
{file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
]
[package.dependencies]
@@ -3736,17 +3739,17 @@ files = [
[[package]]
name = "sympy"
version = "1.12.1"
version = "1.12"
description = "Computer algebra system (CAS) in Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
]
[package.dependencies]
mpmath = ">=1.1.0,<1.4.0"
mpmath = ">=0.19"
[[package]]
name = "tbb"
@@ -4260,13 +4263,13 @@ multidict = ">=4.0"
[[package]]
name = "zarr"
version = "2.18.2"
version = "2.18.1"
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
{file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
{file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
{file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
]
[package.dependencies]
@@ -4281,13 +4284,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
[[package]]
name = "zipp"
version = "3.19.0"
version = "3.18.2"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
{file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
{file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
]
[package.extras]

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
size 5104
oid sha256:ebd21273f6048b66c806f92035352843a9069908b3296863fd55d34cf71cd0ef
size 51248

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
oid sha256:b9bbf951891077320a5da27e77ddb580a6e833e8d3162b62a2f887a1989585cc
size 31688

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
oid sha256:d4070bd1f1cd8c72bc2daf628088e42b8ef113f6df0bfd9e91be052bc90038c3
size 68

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
oid sha256:42f92239223bb4df32d5c3016bc67450159f1285a7ab046307b645f699ccc34e
size 34928

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
size 5104
oid sha256:52f85d6262ad1dd0b66578b25829fed96aaaca3c7458cb73ac75111350d17fcf
size 51248

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
oid sha256:5ba7c910618f0f3ca69f82f3d70c880d2b2e432456524a2a63dfd5c50efa45f0
size 30808

View File

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
oid sha256:53ad410f43855254438790f54aa7c895a052776acdd922906ae430684f659b53
size 33608

View File

@@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
batch = next(iter(dataloader))
obs = {}
for k in batch:
if k.startswith("observation"):
if "observation" in k:
obs[k] = batch[k]
if "n_action_steps" in cfg.policy:
@@ -115,8 +115,8 @@ if __name__ == "__main__":
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)

View File

@@ -16,7 +16,6 @@
import json
import logging
from copy import deepcopy
from itertools import chain
from pathlib import Path
import einops
@@ -26,34 +25,26 @@ from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, repo_id, policy_name):
"""
Tests that:
- we can create a dataset with the factory.
- for a commonly used set of data keys, the data dimensions are correct.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
@@ -114,39 +105,6 @@ def test_factory(env_name, repo_id, policy_name):
assert key in item, f"{key}"
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@@ -357,31 +315,3 @@ def test_backward_compatibility(repo_id):
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))

View File

@@ -29,8 +29,8 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
def _run_script(path, args=None):
subprocess.run([sys.executable, path] + args if args is not None else [], check=True)
def _run_script(path):
subprocess.run([sys.executable, path], check=True)
def _read_file(path):
@@ -126,22 +126,3 @@ def test_examples_basic2_basic3_advanced1():
# Restore stdout to its original state
sys.stdout = sys.__stdout__
assert "Average loss on validation set" in printed_output
def test_real_world_recording():
path = "examples/real_robot_example/record_training_data.py"
_run_script(
path,
[
"--data_dir",
"outputs/examples",
"--repo-id",
"real_world_debug",
"--num-episodes",
"2",
"--num-frames",
"10",
"--mock-robot",
],
)
assert Path("outputs/examples/real_world_debug/video/episode_0.mp4").exists()

View File

@@ -86,9 +86,6 @@ def test_policy(env_name, policy_name, extra_overrides):
- Updating the policy.
- Using the policy to select actions at inference time.
- Test the action can be applied to the policy
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
@@ -140,7 +137,7 @@ def test_policy(env_name, policy_name, extra_overrides):
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
num_workers=4,
batch_size=2,
shuffle=True,
pin_memory=DEVICE != "cpu",

View File

@@ -1,90 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)
def test_drop_n_first_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
def test_drop_n_last_frames():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
def test_episode_indices_to_use():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
def test_shuffle():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [0, 0, 1, 2, 2, 2],
},
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}