forked from tangger/lerobot
Compare commits
27 Commits
test/robot...recovered-
| Author | SHA1 | Date |
|---|---|---|
| | ef45ea9649 | |
| | bc351a0134 | |
| | 68986f6fc0 | |
| | 2f124e34de | |
| | c28e774234 | |
| | 80b1a97e4c | |
| | f4fec8f51c | |
| | f4f82c916f | |
| | ecbe154709 | |
| | d00c154db9 | |
| | 55f284b306 | |
| | cf8df17d3a | |
| | e079566597 | |
| | 83d6419d70 | |
| | a0ec9e1cb1 | |
| | 3eede4447d | |
| | 9c6a7d9701 | |
| | 7b201773f3 | |
| | bfd26eef5a | |
| | 1537d0ab90 | |
| | 2be7f3a3ff | |
| | 0cf864870c | |
| | 1786916a16 | |
| | 0507ad4f68 | |
| | bed90e3a41 | |
| | 6163daaaa4 | |
| | 8e2a394442 | |
24  .github/workflows/build-docker-images.yml (vendored)
@@ -40,24 +40,24 @@ jobs:
           git lfs install

       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           lfs: true
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push CPU
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-cpu/Dockerfile
@@ -78,24 +78,24 @@ jobs:
           git lfs install

       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           lfs: true
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push GPU
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-gpu/Dockerfile
@@ -110,23 +110,23 @@ jobs:
       group: aws-general-8-plus
     steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Login to DockerHub
-        uses: docker/login-action@v3
+        uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
         with:
           username: ${{ secrets.DOCKERHUB_USERNAME }}
           password: ${{ secrets.DOCKERHUB_PASSWORD }}

       - name: Build and Push GPU dev
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           context: .
           file: ./docker/lerobot-gpu-dev/Dockerfile
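The changes above (and the matching ones in the workflows below) replace floating action tags with full commit SHAs, so the workflows cannot be affected by a retagged release. If you want to reproduce a pinned SHA for a given release tag yourself, a small helper along these lines works; this is an illustrative sketch, not part of the repository, and it assumes `git` is available on PATH:

```python
import subprocess


def resolve_action_sha(repo: str, tag: str) -> str:
    """Resolve a GitHub Action release tag to the commit SHA you would pin in `uses:`."""
    url = f"https://github.com/{repo}"
    # `git ls-remote` prints "<sha>\t<ref>" lines; annotated tags also get a peeled "<ref>^{}" line.
    out = subprocess.run(
        ["git", "ls-remote", "--tags", url, tag],
        capture_output=True, text=True, check=True,
    ).stdout
    rows = [line.split("\t") for line in out.strip().splitlines()]
    if not rows:
        raise ValueError(f"tag {tag!r} not found in {repo}")
    # Prefer the peeled commit of an annotated tag; a lightweight tag already points at the commit.
    peeled = [sha for sha, ref in rows if ref.endswith("^{}")]
    return peeled[0] if peeled else rows[0][0]


if __name__ == "__main__":
    # Per the diff above, this should resolve to 11bd71901bbe5b1630ceea73d27597364c9af683.
    print(resolve_action_sha("actions/checkout", "v4.2.2"))
```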
4  .github/workflows/nightly-tests.yml (vendored)
@@ -33,7 +33,7 @@ jobs:
     runs-on:
       group: aws-general-8-plus
     container:
-      image: huggingface/lerobot-cpu:latest
+      image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images]
       options: --shm-size "16gb"
       credentials:
         username: ${{ secrets.DOCKERHUB_USERNAME }}
@@ -60,7 +60,7 @@ jobs:
       CUDA_VISIBLE_DEVICES: "0"
       TEST_TYPE: "single_gpu"
     container:
-      image: huggingface/lerobot-gpu:latest
+      image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images]
       options: --gpus all --shm-size "16gb"
       credentials:
         username: ${{ secrets.DOCKERHUB_USERNAME }}
8  .github/workflows/quality.yml (vendored)
@@ -33,12 +33,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1
         with:
           python-version: ${{ env.PYTHON_VERSION }}
@@ -64,9 +64,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout Repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: typos-action
-        uses: crate-ci/typos@v1.29.10
+        uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10
8  .github/workflows/test-docker-build.yml (vendored)
@@ -35,7 +35,7 @@ jobs:
       matrix: ${{ steps.set-matrix.outputs.matrix }}
     steps:
       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false
@@ -64,17 +64,17 @@ jobs:
         docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }}
     steps:
       - name: Set up Docker Buildx
-        uses: docker/setup-buildx-action@v3
+        uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0
         with:
           cache-binary: false

       - name: Check out code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           persist-credentials: false

       - name: Build Docker image
-        uses: docker/build-push-action@v5
+        uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0
         with:
           file: ${{ matrix.docker-file }}
           context: .
12  .github/workflows/test.yml (vendored)
@@ -50,7 +50,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -62,7 +62,7 @@ jobs:
          sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev

      - name: Install uv and python
-       uses: astral-sh/setup-uv@v5
+       uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
@@ -85,7 +85,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -94,7 +94,7 @@ jobs:
        run: sudo apt-get update && sudo apt-get install -y ffmpeg

      - name: Install uv and python
-       uses: astral-sh/setup-uv@v5
+       uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
@@ -117,7 +117,7 @@ jobs:
     env:
       MUJOCO_GL: egl
     steps:
-      - uses: actions/checkout@v4
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          lfs: true # Ensure LFS files are pulled
          persist-credentials: false
@@ -129,7 +129,7 @@ jobs:
          sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev

      - name: Install uv and python
-       uses: astral-sh/setup-uv@v5
+       uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2
        with:
          enable-cache: true
          version: ${{ env.UV_VERSION }}
4  .github/workflows/trufflehog.yml (vendored)
@@ -24,12 +24,12 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout code
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
         with:
           fetch-depth: 0
           persist-credentials: false

       - name: Secret Scanning
-        uses: trufflesecurity/trufflehog@main
+        uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35
         with:
           extra_args: --only-verified
@@ -37,18 +37,18 @@ repos:
       - id: trailing-whitespace

   - repo: https://github.com/adhtruong/mirrors-typos
-    rev: v1.31.1
+    rev: v1.32.0
     hooks:
       - id: typos
         args: [--force-exclude]

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.1
+    rev: v3.20.0
     hooks:
       - id: pyupgrade

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.5
+    rev: v0.11.11
     hooks:
       - id: ruff
         args: [--fix]
@@ -57,12 +57,12 @@ repos:

   ##### Security #####
   - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.24.3
+    rev: v8.26.0
     hooks:
       - id: gitleaks

   - repo: https://github.com/woodruffw/zizmor-pre-commit
-    rev: v1.5.2
+    rev: v1.8.0
     hooks:
       - id: zizmor
@@ -360,7 +360,7 @@ with profile(
 If you want, you can cite this work with:
 ```bibtex
 @misc{cadene2024lerobot,
-    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas},
+    author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
     title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
     howpublished = "\url{https://github.com/huggingface/lerobot}",
     year = {2024}
@@ -55,7 +55,7 @@ conda install ffmpeg -c conda-forge

 Install 🤗 LeRobot:
 ```bash
-cd lerobot && pip install ".[feetech]"
+cd lerobot && pip install -e ".[feetech]"
 ```

 ## Troubleshooting

@@ -141,7 +141,7 @@ python lerobot/scripts/configure_motor.py \
   --ID 1
 ```

-Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
+Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
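As a quick sanity check of the arithmetic in that note (4096 steps per full turn, midpoint 2048, so the homing offset can absorb at most ± 2048 steps, i.e. ± 180 degrees), here is a tiny sketch; the helper names are ours, not part of the repository:

```python
# One full motor turn is 4096 steps; 2048 steps therefore correspond to 180 degrees.
STEPS_PER_TURN = 4096


def steps_to_degrees(steps: int) -> float:
    return steps * 360.0 / STEPS_PER_TURN


def degrees_to_steps(degrees: float) -> int:
    return round(degrees * STEPS_PER_TURN / 360.0)


assert steps_to_degrees(2048) == 180.0
assert degrees_to_steps(-180.0) == -2048
```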
 Then unplug your motor and plug the second motor and set its ID to 2.
 ```bash
@@ -61,7 +61,7 @@ conda install ffmpeg -c conda-forge

 Install 🤗 LeRobot:
 ```bash
-cd lerobot && pip install ".[feetech]"
+cd lerobot && pip install -e ".[feetech]"
 ```

 > [!NOTE]
@@ -168,12 +168,7 @@ available_datasets = sorted(
 )

 # lists all available policies from `lerobot/common/policies`
-available_policies = [
-    "act",
-    "diffusion",
-    "tdmpc",
-    "vqbet",
-]
+available_policies = ["act", "diffusion", "tdmpc", "vqbet"]

 # lists all available robots from `lerobot/common/robot_devices/robots`
 available_robots = [
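The collapsed list stays importable exactly as before; a minimal check, assuming `lerobot` is installed:

```python
# Illustrative only: the list lives at the top level of the lerobot package.
import lerobot

assert lerobot.available_policies == ["act", "diffusion", "tdmpc", "vqbet"]
```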
@@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int):
 class AsyncImageWriter:
     """
     This class abstract away the initialisation of processes or/and threads to
-    save images on disk asynchrounously, which is critical to control a robot and record data
+    save images on disk asynchronously, which is critical to control a robot and record data
     at a high frame rate.

     When `num_processes=0`, it creates a threads pool of size `num_threads`.
@@ -944,7 +944,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def stop_image_writer(self) -> None:
         """
         Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
-        remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
+        remove the image_writer in order for the LeRobotDataset object to be picklable and parallelized.
         """
         if self.image_writer is not None:
             self.image_writer.stop()
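A short sketch of the usage pattern that docstring describes; the `repo_id` is illustrative, and `stop_image_writer()` is called before handing the dataset to a multi-worker `DataLoader` so the object stays picklable:

```python
import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# repo_id is illustrative; any LeRobotDataset works the same way.
dataset = LeRobotDataset("lerobot/pusht")
dataset.stop_image_writer()  # no-op when no image writer is attached
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4, shuffle=True)
```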
@@ -101,7 +101,7 @@ def decode_video_frames_torchvision(
     keyframes_only = False
     torchvision.set_video_backend(backend)
     if backend == "pyav":
-        keyframes_only = True  # pyav doesnt support accuracte seek
+        keyframes_only = True  # pyav doesn't support accurate seek

     # set a video stream reader
     # TODO(rcadene): also load audio stream at the same time
@@ -15,5 +15,6 @@
 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
 from .pi0.configuration_pi0 import PI0Config as PI0Config
+from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
 from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
@@ -27,6 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
 from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
 from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
 from lerobot.common.policies.pretrained import PreTrainedPolicy
+from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
 from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
 from lerobot.configs.policies import PreTrainedConfig
@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
         from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy

         return PI0FASTPolicy
+    elif name == "smolvla":
+        from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
+
+        return SmolVLAPolicy
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")
@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return PI0Config(**kwargs)
     elif policy_type == "pi0fast":
         return PI0FASTConfig(**kwargs)
+    elif policy_type == "smolvla":
+        return SmolVLAConfig(**kwargs)
     else:
         raise ValueError(f"Policy type '{policy_type}' is not available.")
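A minimal sketch of how the two factory functions extended above are typically called; the module path is an assumption based on the imports in this hunk, the argument values are illustrative, and the `smolvla` extra must be installed:

```python
# Assumed location of the functions changed above: lerobot/common/policies/factory.py
from lerobot.common.policies.factory import get_policy_class, make_policy_config

cfg = make_policy_config("smolvla", chunk_size=50, n_action_steps=50)
policy_cls = get_policy_class("smolvla")  # lazily imports and returns SmolVLAPolicy
print(policy_cls.__name__, cfg.chunk_size)
```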
@@ -357,7 +357,7 @@ class PI0Policy(PreTrainedPolicy):
             if self.config.resize_imgs_with_padding is not None:
                 img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)

-            # Normalize from range [0,1] to [-1,1] as expacted by siglip
+            # Normalize from range [0,1] to [-1,1] as expected by siglip
             img = img * 2.0 - 1.0

             bsize = img.shape[0]
@@ -516,7 +516,7 @@ class PI0FAST(nn.Module):
                 interpolate_like_pi=self.config.interpolate_like_pi,
             )

-            # Normalize from range [0,1] to [-1,1] as expacted by siglip
+            # Normalize from range [0,1] to [-1,1] as expected by siglip
             img = img * 2.0 - 1.0

             bsize = img.shape[0]
154  lerobot/common/policies/smolvla/configuration_smolvla.py (new file)
@@ -0,0 +1,154 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
)
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature


@PreTrainedConfig.register_subclass("smolvla")
@dataclass
class SmolVLAConfig(PreTrainedConfig):
    # Input / output structure.
    n_obs_steps: int = 1
    chunk_size: int = 50
    n_action_steps: int = 50

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Shorter state and action vectors will be padded
    max_state_dim: int = 32
    max_action_dim: int = 32

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] = (512, 512)

    # Add empty images. Used by smolvla_aloha_sim which adds the empty
    # left and right wrist cameras in addition to the top camera.
    empty_cameras: int = 0

    # Converts the joint and gripper values from the standard Aloha space to
    # the space used by the pi internal runtime which was used to train the base model.
    adapt_to_pi_aloha: bool = False

    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
    # Gripper dimensions will remain in absolute values.
    use_delta_joint_actions_aloha: bool = False

    # Tokenizer
    tokenizer_max_length: int = 48

    # Decoding
    num_steps: int = 10

    # Attention utils
    use_cache: bool = True

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = True
    train_state_proj: bool = True

    # Training presets
    optimizer_lr: float = 1e-4
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 1e-10
    optimizer_grad_clip_norm: float = 10

    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"  # Select the VLM backbone.
    load_vlm_weights: bool = False  # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights

    add_image_special_tokens: bool = False  # Whether to use special image tokens around image features.

    attention_mode: str = "cross_attn"

    prefix_length: int = -1

    pad_language_to: str = "longest"  # "max_length"

    num_expert_layers: int = -1  # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers.
    num_vlm_layers: int = 16  # Number of layers used in the VLM (first num_vlm_layers layers)
    self_attn_every_n_layers: int = 2  # Interleave SA layers each self_attn_every_n_layers
    expert_width_multiplier: float = 0.75  # The action expert hidden size (wrt to the VLM)

    min_period: float = 4e-3  # sensitivity range for the timestep used in sine-cosine positional encoding
    max_period: float = 4.0

    def __post_init__(self):
        super().__post_init__()

        """Input validation (not exhaustive)."""
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )
        if self.use_delta_joint_actions_aloha:
            raise NotImplementedError(
                "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot."
            )

    def validate_features(self) -> None:
        for i in range(self.empty_cameras):
            key = f"observation.images.empty_camera_{i}"
            empty_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 480, 640),
            )
            self.input_features[key] = empty_camera

    def get_optimizer_preset(self) -> AdamWConfig:
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list:
        return [0]

    @property
    def action_delta_indices(self) -> list:
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        return None
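A minimal sketch of instantiating the config above and reading the presets it exposes; it assumes the inherited `PreTrainedConfig` fields can keep their defaults, and the override values are illustrative:

```python
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig

# n_action_steps must not exceed chunk_size (see __post_init__ above).
cfg = SmolVLAConfig(chunk_size=50, n_action_steps=25)
optimizer = cfg.get_optimizer_preset()  # AdamWConfig(lr=1e-4, betas=(0.9, 0.95), ...)
scheduler = cfg.get_scheduler_preset()  # cosine decay with 1_000 warmup steps
print(optimizer.lr, cfg.action_delta_indices[-1])  # 0.0001 49
```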
801  lerobot/common/policies/smolvla/modeling_smolvla.py (new file)
@@ -0,0 +1,801 @@
#!/usr/bin/env python

# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SmolVLA:

[Paper](https://huggingface.co/papers/2506.01844)

Designed by Hugging Face.

Install smolvla extra dependencies:
```bash
pip install -e ".[smolvla]"
```

Example of finetuning the smolvla pretrained model (`smolvla_base`):
```bash
python lerobot/scripts/train.py \
--policy.path=lerobot/smolvla_base \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```

Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM,
and an action expert.
```bash
python lerobot/scripts/train.py \
--policy.type=smolvla \
--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \
--batch_size=64 \
--steps=200000
```

Example of using the smolvla pretrained model outside LeRobot training framework:
```python
policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
```
"""
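Building on the `from_pretrained` example in the docstring above, here is a hedged sketch of step-by-step inference; the observation keys follow LeRobot conventions, but the exact camera keys and state dimension depend on the checkpoint's input features, and the tensors below are random placeholders rather than real robot data:

```python
import torch

from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy

policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
policy.reset()  # clears the internal action queue; call this on every environment reset

# Placeholder observation; real keys and shapes must match the checkpoint's features.
batch = {
    "observation.state": torch.rand(1, 6),
    "observation.images.top": torch.rand(1, 3, 480, 640),  # pixel range [0, 1]
    "task": ["pick up the cube"],
}
action = policy.select_action(batch)  # one action popped from the predicted chunk
```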
import math
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.policies.normalize import (
|
||||
Normalize,
|
||||
Unnormalize,
|
||||
)
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.common.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
|
||||
|
||||
if time.ndim != 1:
|
||||
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
|
||||
|
||||
dtype = get_safe_dtype(torch.float64, device.type)
|
||||
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
|
||||
# Compute the outer product
|
||||
scaling_factor = 1.0 / period * 2 * math.pi
|
||||
sin_input = scaling_factor[None, :] * time[:, None]
|
||||
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
|
||||
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
|
||||
setup several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if its part of the input, false if padding.
|
||||
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
|
||||
it and 0 where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
if att_masks.ndim != 2:
|
||||
raise ValueError(att_masks.ndim)
|
||||
if pad_masks.ndim != 2:
|
||||
raise ValueError(pad_masks.ndim)
|
||||
|
||||
cumsum = torch.cumsum(att_masks, dim=1)
|
||||
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
|
||||
att_2d_masks = att_2d_masks & pad_2d_masks
|
||||
return att_2d_masks
|
||||
|
||||
|
||||
def resize_with_pad(img, width, height, pad_value=-1):
|
||||
# assume no-op when width height fits already
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
|
||||
|
||||
cur_height, cur_width = img.shape[2:]
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, int(height - resized_height))
|
||||
pad_width = max(0, int(width - resized_width))
|
||||
|
||||
# pad on left and top of image
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
|
||||
|
||||
def pad_vector(vector, new_dim):
|
||||
"""Can be (batch_size x sequence_length x features_dimension)
|
||||
or (batch_size x features_dimension)
|
||||
"""
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
|
||||
new_vector[..., :current_dim] = vector
|
||||
return new_vector
|
||||
|
||||
|
||||
def normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def safe_arcsin(value):
|
||||
# This ensures that the input stays within
|
||||
# [−1,1] to avoid invalid values for arcsin
|
||||
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
|
||||
|
||||
|
||||
def aloha_gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with smolvla which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return safe_arcsin(value)
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular(value):
|
||||
# Convert from the gripper position used by smolvla to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def aloha_gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
class SmolVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around VLAFlowMatching model to train and run inference within LeRobot."""
|
||||
|
||||
config_class = SmolVLAConfig
|
||||
name = "smolvla"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SmolVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
for k in batch:
|
||||
if k in self._queues:
|
||||
batch[k] = torch.stack(list(self._queues[k]), dim=1)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise
|
||||
)
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
actions = self._pi_aloha_encode_actions(actions)
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
images, img_masks = self.prepare_images(batch)
|
||||
state = self.prepare_state(batch)
|
||||
lang_tokens, lang_masks = self.prepare_language(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
actions_is_pad = batch.get("actions_id_pad")
|
||||
loss_dict = {}
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||
loss_dict["losses_after_forward"] = losses.clone()
|
||||
|
||||
if actions_is_pad is not None:
|
||||
in_episode_bound = ~actions_is_pad
|
||||
losses = losses * in_episode_bound.unsqueeze(-1)
|
||||
loss_dict["losses_after_in_ep_bound"] = losses.clone()
|
||||
|
||||
# Remove padding
|
||||
losses = losses[:, :, : self.config.max_action_dim]
|
||||
loss_dict["losses_after_rm_padding"] = losses.clone()
|
||||
|
||||
# For backward pass
|
||||
loss = losses.mean()
|
||||
# For backward pass
|
||||
loss_dict["loss"] = loss
|
||||
return loss, loss_dict
|
||||
|
||||
def prepare_images(self, batch):
|
||||
"""Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
|
||||
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
|
||||
"""
|
||||
images = []
|
||||
img_masks = []
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
missing_img_keys = [key for key in self.config.image_features if key not in batch]
|
||||
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
|
||||
)
|
||||
# Preprocess image features present in the batch
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
|
||||
|
||||
# Normalize from range [0,1] to [-1,1] as expacted by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
bsize = img.shape[0]
|
||||
device = img.device
|
||||
if f"{key}_padding_mask" in batch:
|
||||
mask = batch[f"{key}_padding_mask"].bool()
|
||||
else:
|
||||
mask = torch.ones(bsize, dtype=torch.bool, device=device)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
|
||||
# Create image features not present in the batch
|
||||
# as fully 0 padded images.
|
||||
for num_empty_cameras in range(len(missing_img_keys)):
|
||||
if num_empty_cameras >= self.config.empty_cameras:
|
||||
break
|
||||
img = torch.ones_like(img) * -1
|
||||
mask = torch.zeros_like(mask)
|
||||
images.append(img)
|
||||
img_masks.append(mask)
|
||||
return images, img_masks
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
tasks = batch["task"]
|
||||
if len(tasks) == 1:
|
||||
tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])]
|
||||
|
||||
tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks]
|
||||
tokenized_prompt = self.language_tokenizer.__call__(
|
||||
tasks,
|
||||
padding=self.config.pad_language_to,
|
||||
padding_side="right",
|
||||
max_length=self.config.tokenizer_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
lang_tokens = tokenized_prompt["input_ids"].to(device=device)
|
||||
lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool)
|
||||
|
||||
return lang_tokens, lang_masks
|
||||
|
||||
def _pi_aloha_decode_state(self, state):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
state[:, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
|
||||
return state
|
||||
|
||||
def _pi_aloha_encode_actions(self, actions):
|
||||
# Flip the joints.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def _pi_aloha_encode_actions_inv(self, actions):
|
||||
# Flip the joints again.
|
||||
for motor_idx in [1, 2, 8, 9]:
|
||||
actions[:, :, motor_idx] *= -1
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
for motor_idx in [6, 13]:
|
||||
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
|
||||
return actions
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT]
|
||||
state = pad_vector(state, self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
"""Pad action"""
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
Efficiently pads a tensor along sequence dimension to match max_len.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): Shape (B, L, ...) or (B, L).
|
||||
max_len (int): Fixed sequence length.
|
||||
pad_value (int/float): Value for padding.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Shape (B, max_len, ...) or (B, max_len).
|
||||
"""
|
||||
b, d = tensor.shape[:2]
|
||||
|
||||
# Create a padded tensor of max_len and copy the existing values
|
||||
padded_tensor = torch.full(
|
||||
(b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device
|
||||
)
|
||||
padded_tensor[:, :d] = tensor # Efficient in-place copy
|
||||
|
||||
return padded_tensor
|
||||
|
||||
|
||||
class VLAFlowMatching(nn.Module):
|
||||
"""
|
||||
SmolVLA
|
||||
|
||||
[Paper]()
|
||||
|
||||
Designed by Hugging Face.
|
||||
┌──────────────────────────────┐
|
||||
│ actions │
|
||||
│ ▲ │
|
||||
│ ┌─────────┐ ┌─|────┐ │
|
||||
│ | │────► │ │ │
|
||||
│ | │ kv │ │ │
|
||||
│ | │────► │Action│ │
|
||||
│ | VLM │cache │Expert│ |
|
||||
│ │ │────► | │ │
|
||||
│ │ │ │ │ │
|
||||
│ └▲──▲───▲─┘ └───▲──┘ |
|
||||
│ │ | | │ |
|
||||
│ | | | noise │
|
||||
│ │ │ state │
|
||||
│ │ language tokens │
|
||||
│ image(s) │
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.vlm_with_expert = SmolVLMWithExpertModel(
|
||||
model_id=self.config.vlm_model_name,
|
||||
freeze_vision_encoder=self.config.freeze_vision_encoder,
|
||||
train_expert_only=self.config.train_expert_only,
|
||||
load_vlm_weights=self.config.load_vlm_weights,
|
||||
attention_mode=self.config.attention_mode,
|
||||
num_expert_layers=self.config.num_expert_layers,
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
)
|
||||
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size)
|
||||
self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim)
|
||||
|
||||
self.action_time_mlp_in = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
self.action_time_mlp_out = nn.Linear(
|
||||
self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size
|
||||
)
|
||||
|
||||
self.set_requires_grad()
|
||||
self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id
|
||||
self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id
|
||||
self.global_image_start_token = torch.tensor(
|
||||
[self.fake_image_token, self.global_image_token], dtype=torch.long
|
||||
)
|
||||
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Embed images with SigLIP and language tokens with embedding layer to prepare
|
||||
for SmolVLM transformer processing.
|
||||
"""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
for _img_idx, (
|
||||
img,
|
||||
img_mask,
|
||||
) in enumerate(zip(images, img_masks, strict=False)):
|
||||
if self.add_image_special_tokens:
|
||||
image_start_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_start_mask = torch.ones_like(
|
||||
image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device
|
||||
)
|
||||
att_masks += [0] * (image_start_mask.shape[-1])
|
||||
embs.append(image_start_token)
|
||||
pad_masks.append(image_start_mask)
|
||||
|
||||
img_emb = self.vlm_with_expert.embed_image(img)
|
||||
img_emb = img_emb
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
|
||||
|
||||
bsize, num_img_embs = img_emb.shape[:2]
|
||||
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
|
||||
|
||||
embs.append(img_emb)
|
||||
pad_masks.append(img_mask)
|
||||
|
||||
att_masks += [0] * (num_img_embs)
|
||||
if self.add_image_special_tokens:
|
||||
image_end_token = (
|
||||
self.vlm_with_expert.embed_language_tokens(
|
||||
self.image_end_token.to(device=self.vlm_with_expert.vlm.device)
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(img.shape[0], -1, -1)
|
||||
)
|
||||
image_end_mask = torch.ones_like(
|
||||
image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device
|
||||
)
|
||||
embs.append(image_end_token)
|
||||
pad_masks.append(image_end_mask)
|
||||
att_masks += [0] * (image_end_mask.shape[1])
|
||||
lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens)
|
||||
# Normalize language embeddings
|
||||
lang_emb_dim = lang_emb.shape[-1]
|
||||
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
|
||||
|
||||
embs.append(lang_emb)
|
||||
pad_masks.append(lang_masks)
|
||||
|
||||
num_lang_embs = lang_emb.shape[1]
|
||||
att_masks += [0] * num_lang_embs
|
||||
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb
|
||||
embs.append(state_emb)
|
||||
bsize = state_emb.shape[0]
|
||||
device = state_emb.device
|
||||
|
||||
states_seq_len = state_emb.shape[1]
|
||||
state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device)
|
||||
pad_masks.append(state_mask)
|
||||
|
||||
# Set attention masks so that image and language inputs do not attend to state or actions
|
||||
att_masks += [1] * (states_seq_len)
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
|
||||
att_masks = att_masks[None, :]
|
||||
|
||||
seq_len = pad_masks.shape[1]
|
||||
if seq_len < self.prefix_length:
|
||||
embs = pad_tensor(embs, self.prefix_length, pad_value=0)
|
||||
pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0)
|
||||
att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0)
|
||||
|
||||
att_masks = att_masks.expand(bsize, -1)
|
||||
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def embed_suffix(self, noisy_actions, timestep):
|
||||
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
|
||||
embs = []
|
||||
pad_masks = []
|
||||
att_masks = []
|
||||
|
||||
# Fuse timestep + action information using an MLP
|
||||
action_emb = self.action_in_proj(noisy_actions)
|
||||
device = action_emb.device
|
||||
bsize = action_emb.shape[0]
|
||||
dtype = action_emb.dtype
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.vlm_with_expert.expert_hidden_size,
|
||||
self.config.min_period,
|
||||
self.config.max_period,
|
||||
device=device,
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
time_emb = time_emb[:, None, :].expand_as(action_emb)
|
||||
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
|
||||
|
||||
action_time_emb = self.action_time_mlp_in(action_time_emb)
|
||||
action_time_emb = F.silu(action_time_emb) # swish == silu
|
||||
action_time_emb = self.action_time_mlp_out(action_time_emb)
|
||||
|
||||
# Add to input tokens
|
||||
embs.append(action_time_emb)
|
||||
|
||||
bsize, action_time_dim = action_time_emb.shape[:2]
|
||||
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
|
||||
pad_masks.append(action_time_mask)
|
||||
|
||||
# Set attention masks so that image, language and state inputs do not attend to action tokens
|
||||
att_masks += [1] * self.config.chunk_size
|
||||
embs = torch.cat(embs, dim=1)
|
||||
pad_masks = torch.cat(pad_masks, dim=1)
|
||||
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
|
||||
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time)
|
||||
|
||||
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||
|
||||
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||
(_, suffix_out), _ = self.vlm_with_expert.forward(
|
||||
attention_mask=att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, suffix_embs],
|
||||
use_cache=False,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
# Original openpi code, upcast attention output
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
)
|
||||
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
|
||||
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
|
||||
# Compute image and language key value cache
|
||||
_, past_key_values = self.vlm_with_expert.forward(
|
||||
attention_mask=prefix_att_2d_masks,
|
||||
position_ids=prefix_position_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=[prefix_embs, None],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||
|
||||
suffix_len = suffix_pad_masks.shape[1]
|
||||
batch_size = prefix_pad_masks.shape[0]
|
||||
prefix_len = prefix_pad_masks.shape[1]
|
||||
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
|
||||
|
||||
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
|
||||
|
||||
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
|
||||
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
|
||||
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
|
||||
|
||||
outputs_embeds, _ = self.vlm_with_expert.forward(
|
||||
attention_mask=full_att_2d_masks,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=[None, suffix_embs],
|
||||
use_cache=self.config.use_cache,
|
||||
fill_kv_cache=False,
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
550  lerobot/common/policies/smolvla/smolvlm_with_expert.py (new file)
@@ -0,0 +1,550 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoModelForImageTextToText,
|
||||
AutoProcessor,
|
||||
SmolVLMForConditionalGeneration,
|
||||
)
|
||||
|
||||
|
||||
def apply_rope(x, positions, max_wavelength=10_000):
|
||||
"""
|
||||
Applies RoPE positions [B, L] to x [B, L, H, D].
|
||||
"""
|
||||
d_half = x.shape[-1] // 2
|
||||
device = x.device
|
||||
dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
|
||||
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
|
||||
|
||||
radians = radians[..., None, :]
|
||||
|
||||
sin = torch.sin(radians) # .to(dtype=dtype)
|
||||
cos = torch.cos(radians) # .to(dtype=dtype)
|
||||
|
||||
x1, x2 = x.split(d_half, dim=-1)
|
||||
res = torch.empty_like(x)
|
||||
res[..., :d_half] = x1 * cos - x2 * sin
|
||||
res[..., d_half:] = x2 * cos + x1 * sin
|
||||
|
||||
return res.to(dtype)
|
||||
|
||||
|
||||
def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
|
||||
hidden_dim = int(2 * hidden_dim / 3)
|
||||
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
return hidden_dim
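For orientation, this mirrors the SwiGLU-style FFN sizing used in Llama-family models: take two thirds of the hidden size, apply the multiplier, then round up to a multiple of `multiple_of`. Two illustrative values (chosen for the example, not read from the SmolVLM config):

# 480 -> int(2*480/3) = 320 -> *4 = 1280 -> already a multiple of 256 -> 1280
assert get_intermediate_size(480) == 1280
# 512 -> int(2*512/3) = 341 -> *4 = 1364 -> rounded up to the next multiple of 256 -> 1536
assert get_intermediate_size(512) == 1536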
|
||||
|
||||
|
||||
class SmolVLMWithExpertModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct",
|
||||
load_vlm_weights: bool = True,
|
||||
train_expert_only: bool = True,
|
||||
freeze_vision_encoder: bool = False,
|
||||
attention_mode: str = "self_attn",
|
||||
num_expert_layers: int = -1,
|
||||
num_vlm_layers: int = -1,
|
||||
self_attn_every_n_layers: int = -1,
|
||||
expert_width_multiplier: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
if load_vlm_weights:
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
config = self.vlm.config
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(model_id)
|
||||
self.vlm = SmolVLMForConditionalGeneration(config=config)
|
||||
self.processor = AutoProcessor.from_pretrained(model_id)
|
||||
if num_vlm_layers > 0:
|
||||
print(f"Reducing the number of VLM layers to {num_vlm_layers} ...")
|
||||
self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers]
|
||||
self.num_vlm_layers = len(self.get_vlm_model().text_model.layers)
|
||||
self.config = config
|
||||
# Smaller lm expert
|
||||
lm_expert_config = copy.deepcopy(config.text_config)
|
||||
hidden_size = lm_expert_config.hidden_size
|
||||
lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2
|
||||
lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
|
||||
lm_expert_config.num_hidden_layers = self.num_vlm_layers
|
||||
if num_expert_layers > 0:
|
||||
assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, (
|
||||
f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}"
|
||||
)
|
||||
lm_expert_config.num_hidden_layers = num_expert_layers
|
||||
self.lm_expert = AutoModel.from_config(lm_expert_config)
|
||||
|
||||
self.num_expert_layers = len(self.lm_expert.layers)
|
||||
self.self_attn_every_n_layers = self_attn_every_n_layers
|
||||
if "cross" in attention_mode:
|
||||
# Reshape qkv projections to have the same input dimension as the vlm
|
||||
for layer_idx in range(len(self.lm_expert.layers)):
|
||||
if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0:
|
||||
continue
|
||||
self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear(
|
||||
config.text_config.num_key_value_heads * config.text_config.head_dim,
|
||||
lm_expert_config.num_key_value_heads * lm_expert_config.head_dim,
|
||||
bias=lm_expert_config.attention_bias,
|
||||
)
|
||||
# Remove unused embed_tokens
|
||||
self.lm_expert.embed_tokens = None
|
||||
|
||||
self.num_attention_heads = self.config.text_config.num_attention_heads
|
||||
self.num_key_value_heads = self.config.text_config.num_key_value_heads
|
||||
|
||||
self.freeze_vision_encoder = freeze_vision_encoder
|
||||
self.train_expert_only = train_expert_only
|
||||
self.attention_mode = attention_mode
|
||||
self.expert_hidden_size = lm_expert_config.hidden_size
|
||||
self.set_requires_grad()
|
||||
|
||||
def get_vlm_model(self):
|
||||
return self.vlm.model
|
||||
|
||||
def set_requires_grad(self):
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
for params in self.get_vlm_model().vision_model.parameters():
|
||||
params.requires_grad = False
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
for params in self.vlm.parameters():
|
||||
params.requires_grad = False
|
||||
else:
|
||||
# To avoid unused params issue with distributed training
|
||||
last_layers = [self.num_vlm_layers - 1]
|
||||
if (
|
||||
self.num_vlm_layers != self.num_expert_layers
|
||||
and self.num_vlm_layers % self.num_expert_layers == 0
|
||||
):
|
||||
last_layers.append(self.num_vlm_layers - 2)
|
||||
frozen_layers = [
|
||||
"lm_head",
|
||||
"text_model.model.norm.weight",
|
||||
]
|
||||
for layer in last_layers:
|
||||
frozen_layers.append(f"text_model.model.layers.{layer}.")
|
||||
|
||||
for name, params in self.vlm.named_parameters():
|
||||
if any(k in name for k in frozen_layers):
|
||||
params.requires_grad = False
|
||||
# To avoid unused params issue with distributed training
|
||||
for name, params in self.lm_expert.named_parameters():
|
||||
if "lm_head" in name:
|
||||
params.requires_grad = False
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
super().train(mode)
|
||||
|
||||
if self.freeze_vision_encoder:
|
||||
self.get_vlm_model().vision_model.eval()
|
||||
|
||||
if self.train_expert_only:
|
||||
self.vlm.eval()
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
patch_attention_mask = None
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = (
|
||||
self.get_vlm_model()
|
||||
.vision_model(
|
||||
pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype),
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
.last_hidden_state
|
||||
)
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.get_vlm_model().connector(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.get_vlm_model().text_model.get_input_embeddings()(tokens)
|
||||
|
||||
def forward_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
query_states = []
|
||||
key_states = []
|
||||
value_states = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
if hidden_states is None or layer is None:
|
||||
continue
|
||||
hidden_states = layer.input_layernorm(hidden_states)
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
query_states.append(query_state)
|
||||
key_states.append(key_state)
|
||||
value_states.append(value_state)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
# concatenate on the number of embeddings/tokens
|
||||
query_states = torch.cat(query_states, dim=1)
|
||||
key_states = torch.cat(key_states, dim=1)
|
||||
value_states = torch.cat(value_states, dim=1)
|
||||
seq_len = query_states.shape[1]
|
||||
if seq_len < position_ids.shape[1]:
|
||||
_position_ids = position_ids[:, :seq_len]
|
||||
_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
else:
|
||||
_position_ids = position_ids
|
||||
_attention_mask = attention_mask
|
||||
|
||||
attention_mask_ = _attention_mask
|
||||
position_ids_ = _position_ids
|
||||
|
||||
query_states = apply_rope(query_states, position_ids_)
|
||||
key_states = apply_rope(key_states, position_ids_)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_output = attention_interface(
|
||||
attention_mask_, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
return [att_output], past_key_values
|
||||
|
||||
def forward_cross_attn_layer(
|
||||
self,
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache: bool = True,
|
||||
fill_kv_cache: bool = True,
|
||||
past_key_values=None,
|
||||
) -> list[torch.Tensor]:
|
||||
attention_interface = self.get_attention_interface()
|
||||
|
||||
att_outputs = []
|
||||
assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), (
|
||||
f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}"
|
||||
)
|
||||
|
||||
if len(inputs_embeds) == 2 and not past_key_values:
|
||||
# Prefix attention
|
||||
seq_len = inputs_embeds[0].shape[1]
|
||||
position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:]
|
||||
prefix_attention_mask = attention_mask[:, :seq_len, :seq_len]
|
||||
|
||||
layer = model_layers[0][layer_idx]
|
||||
|
||||
hidden_states = layer.input_layernorm(inputs_embeds[0])
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
|
||||
# B,L,H,D with L sequence length, H number of heads, D head dim
|
||||
query_states = apply_rope(query_state, position_id)
|
||||
key_states = apply_rope(key_state, position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
expert_position_id = position_ids
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = {}
|
||||
|
||||
if use_cache:
|
||||
if fill_kv_cache:
|
||||
past_key_values[layer_idx] = {
|
||||
"key_states": key_states,
|
||||
"value_states": value_states,
|
||||
}
|
||||
else:
|
||||
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
|
||||
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
|
||||
# the max len, then we (for instance) double the cache size. This implementation already exists
|
||||
# in `transformers`. (molbap)
|
||||
key_states = past_key_values[layer_idx]["key_states"]
|
||||
value_states = past_key_values[layer_idx]["value_states"]
|
||||
|
||||
# Expert
|
||||
expert_layer = model_layers[1][layer_idx]
|
||||
if expert_layer is not None:
|
||||
expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1])
|
||||
|
||||
expert_input_shape = expert_hidden_states.shape[:-1]
|
||||
expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim)
|
||||
|
||||
expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype)
|
||||
expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape)
|
||||
|
||||
_key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view(
|
||||
*key_states.shape[:2], -1
|
||||
)
|
||||
expert_key_states = expert_layer.self_attn.k_proj(_key_states).view(
|
||||
*_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
) # k_proj should have same dim as kv
|
||||
|
||||
_value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view(
|
||||
*value_states.shape[:2], -1
|
||||
)
|
||||
expert_value_states = expert_layer.self_attn.v_proj(_value_states).view(
|
||||
*_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim
|
||||
)
|
||||
|
||||
expert_position_id = (
|
||||
expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values
|
||||
) # start from 0
|
||||
expert_attention_mask = attention_mask[
|
||||
:, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] :
|
||||
] # take into account kv
|
||||
|
||||
expert_query_states = apply_rope(expert_query_state, expert_position_id)
|
||||
|
||||
att_output = attention_interface(
|
||||
expert_attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
expert_query_states,
|
||||
expert_key_states,
|
||||
expert_value_states,
|
||||
)
|
||||
att_outputs.append(att_output)
|
||||
else:
|
||||
att_outputs.append(None)
|
||||
|
||||
# att_output = att_output.to(dtype=models[i].dtype)
|
||||
return att_outputs, past_key_values
|
||||
|
||||
def get_model_layers(self, models: list) -> list:
|
||||
vlm_layers = []
|
||||
expert_layers = []
|
||||
multiple_of = self.num_vlm_layers // self.num_expert_layers
|
||||
for i in range(self.num_vlm_layers):
|
||||
if multiple_of > 0 and i > 0 and i % multiple_of != 0:
|
||||
expert_layer = None
|
||||
else:
|
||||
expert_layer_index = i // multiple_of if multiple_of > 0 else i
|
||||
expert_layer = models[1].layers[expert_layer_index]
|
||||
vlm_layers.append(models[0].layers[i])
|
||||
expert_layers.append(expert_layer)
|
||||
return [vlm_layers, expert_layers]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: List[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.get_vlm_model().text_model, self.lm_expert]
|
||||
model_layers = self.get_model_layers(models)
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
# dtype is always the same, batch size too (if > 1 len)
|
||||
# device could be trickier in multi gpu edge cases but that's it
|
||||
if hidden_states is None:
|
||||
continue
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# RMSNorm
|
||||
num_layers = self.num_vlm_layers
|
||||
head_dim = self.vlm.config.text_config.head_dim
|
||||
for layer_idx in range(num_layers):
|
||||
if (
|
||||
fill_kv_cache
|
||||
or "cross" not in self.attention_mode
|
||||
or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0)
|
||||
):
|
||||
att_outputs, past_key_values = self.forward_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
att_outputs, past_key_values = self.forward_cross_attn_layer(
|
||||
model_layers,
|
||||
inputs_embeds,
|
||||
layer_idx,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
use_cache=use_cache,
|
||||
fill_kv_cache=fill_kv_cache,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
outputs_embeds = []
|
||||
start = 0
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
layer = model_layers[i][layer_idx]
|
||||
att_output = (
|
||||
att_outputs[i] if i < len(att_outputs) else att_outputs[0]
|
||||
) # in case of self_attn
|
||||
if hidden_states is not None:
|
||||
if layer is None:
|
||||
outputs_embeds.append(hidden_states)
|
||||
continue
|
||||
end = start + hidden_states.shape[1]
|
||||
|
||||
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
att_out = att_output[:, start:end]
|
||||
out_emb = layer.self_attn.o_proj(att_out)
|
||||
|
||||
out_emb += hidden_states
|
||||
after_first_residual = out_emb.clone()
|
||||
|
||||
out_emb = layer.post_attention_layernorm(out_emb)
|
||||
out_emb = layer.mlp(out_emb)
|
||||
|
||||
out_emb += after_first_residual
|
||||
|
||||
outputs_embeds.append(out_emb)
|
||||
|
||||
start = end if len(att_outputs) == 1 else 0
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
|
||||
inputs_embeds = outputs_embeds
|
||||
|
||||
# final norm
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
if hidden_states is not None:
|
||||
out_emb = models[i].norm(hidden_states)
|
||||
outputs_embeds.append(out_emb)
|
||||
else:
|
||||
outputs_embeds.append(None)
|
||||
return outputs_embeds, past_key_values
|
||||
|
||||
def get_attention_interface(self):
|
||||
attention_interface = self.eager_attention_forward
|
||||
return attention_interface
|
||||
|
||||
def eager_attention_forward(
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.num_attention_heads
|
||||
num_key_value_heads = self.num_key_value_heads
|
||||
num_key_value_groups = num_att_heads // num_key_value_heads
|
||||
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
query_states = query_states.to(dtype=torch.float32)
|
||||
key_states = key_states.to(dtype=torch.float32)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
|
||||
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||
att_weights *= head_dim**-0.5
|
||||
|
||||
att_weights = att_weights.to(dtype=torch.float32)
|
||||
big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py
|
||||
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
|
||||
probs = nn.functional.softmax(masked_att_weights, dim=-1)
|
||||
probs = probs.to(dtype=value_states.dtype)
|
||||
|
||||
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
|
||||
|
||||
att_output = att_output.permute(0, 2, 1, 3)
|
||||
# we use -1 because sequence length can change
|
||||
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
|
||||
|
||||
return att_output
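As a reading aid, the policy above drives this module in two passes: a prefill call with fill_kv_cache=True that embeds the image/language prefix and stores per-layer keys/values, then one call per denoising step with fill_kv_cache=False that feeds only the action-expert suffix against that cache. A rough sketch of the call pattern, with placeholder tensors (model is a SmolVLMWithExpertModel; the embeddings and 2-D masks come from the policy-side helpers):

# prefill: VLM stream only, cache the prefix keys/values
_, kv_cache = model.forward(
    attention_mask=prefix_att_2d_masks,
    position_ids=prefix_position_ids,
    past_key_values=None,
    inputs_embeds=[prefix_embs, None],
    use_cache=True,
    fill_kv_cache=True,
)
# denoising step: expert stream only, reuse the cached prefix
outputs_embeds, _ = model.forward(
    attention_mask=full_att_2d_masks,
    position_ids=suffix_position_ids,
    past_key_values=kv_cache,
    inputs_embeds=[None, suffix_embs],
    use_cache=True,
    fill_kv_cache=False,
)
action_features = outputs_embeds[1]   # expert-side hidden states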
|
||||
@@ -58,7 +58,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||
if not robot.robot_type.startswith("stretch"):
|
||||
if not robot.robot_type.startswith(("stretch", "realman")):
|
||||
for name in robot.leader_arms:
|
||||
key = f"read_leader_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
@@ -109,6 +109,10 @@ def predict_action(observation, policy, device, use_amp):
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
# Skip all observations that are not tensors (e.g. text)
|
||||
if not isinstance(observation[name], torch.Tensor):
|
||||
continue
|
||||
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
@@ -243,6 +247,11 @@ def control_loop(
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
# Control loop starts; if a policy is given it needs to be reset first
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
@@ -251,7 +260,8 @@ def control_loop(
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
action = None
|
||||
|
||||
observation["task"] = [single_task]
|
||||
observation["robot_type"] = [policy.robot_type] if hasattr(policy, "robot_type") else [""]
|
||||
if policy is not None:
|
||||
pred_action = predict_action(
|
||||
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
|
||||
@@ -262,6 +272,7 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
observation = {k: v for k, v in observation.items() if k not in ["task", "robot_type"]}
|
||||
frame = {**observation, **action, "task": single_task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
|
||||
@@ -39,3 +39,12 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
|
||||
port: str
|
||||
motors: dict[str, tuple[int, str]]
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@MotorsBusConfig.register_subclass("realman")
|
||||
@dataclass
|
||||
class RealmanMotorsBusConfig(MotorsBusConfig):
|
||||
ip: str
|
||||
port: int
|
||||
motors: dict[str, tuple[int, str]]
|
||||
init_joint: dict[str, list]
|
||||
lerobot/common/robot_devices/motors/realman.py (new file, 150 lines)
@@ -0,0 +1,150 @@
|
||||
import time
|
||||
from typing import Dict
|
||||
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
|
||||
from Robotic_Arm.rm_robot_interface import *
|
||||
|
||||
|
||||
class RealmanMotorsBus:
|
||||
"""
|
||||
Thin wrapper around the Realman SDK.
|
||||
"""
|
||||
def __init__(self,
|
||||
config: RealmanMotorsBusConfig):
|
||||
self.rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||
self.handle = self.rmarm.rm_create_robot_arm(config.ip, config.port)
|
||||
self.motors = config.motors
|
||||
self.init_joint_position = config.init_joint['joint'] # [6 joints + 1 gripper]
|
||||
self.safe_disable_position = config.init_joint['joint']
|
||||
self.rmarm.rm_movej(self.init_joint_position[:-1], 5, 0, 0, 1)
|
||||
time.sleep(3)
|
||||
ret = self.rmarm.rm_get_current_arm_state()
|
||||
self.init_pose = ret[1]['pose']
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
|
||||
def connect(self, enable=True) -> bool:
|
||||
'''
|
||||
Enable the arm and poll its enable state; try for up to 5 s and give up if it times out.
|
||||
'''
|
||||
enable_flag = False
|
||||
loop_flag = False
|
||||
# timeout (seconds)
|
||||
timeout = 5
|
||||
# record the time before entering the loop
|
||||
start_time = time.time()
|
||||
elapsed_time_flag = False
|
||||
|
||||
while not loop_flag:
|
||||
elapsed_time = time.time() - start_time
|
||||
print("--------------------")
|
||||
|
||||
if enable:
|
||||
# query the arm state
|
||||
ret = self.rmarm.rm_get_current_arm_state()
|
||||
if ret[0] == 0:  # state read succeeded
|
||||
enable_flag = True
|
||||
else:
|
||||
enable_flag = False
|
||||
# drop all connections and destroy the SDK threads
|
||||
RoboticArm.rm_destory()
|
||||
print("使能状态:", enable_flag)
|
||||
print("--------------------")
|
||||
if enable_flag == enable:
|
||||
loop_flag = True
|
||||
enable_flag = True
|
||||
else:
|
||||
loop_flag = False
|
||||
enable_flag = False
|
||||
# check whether the timeout has been exceeded
|
||||
if elapsed_time > timeout:
|
||||
print("超时....")
|
||||
elapsed_time_flag = True
|
||||
enable_flag = True
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
resp = enable_flag
|
||||
print(f"Returning response: {resp}")
|
||||
return resp
|
||||
|
||||
|
||||
|
||||
def set_calibration(self):
|
||||
return
|
||||
|
||||
def revert_calibration(self):
|
||||
return
|
||||
|
||||
def apply_calibration(self):
|
||||
"""
|
||||
Move to the initial position.
|
||||
"""
|
||||
self.write(target_joint=self.init_joint_position)
|
||||
|
||||
def write(self, target_joint:list):
|
||||
# self.rmarm.rm_movej(target_joint[:-1], 50, 0, 0, 1)
|
||||
self.rmarm.rm_movej_follow(target_joint[:-1])
|
||||
self.rmarm.rm_set_gripper_position(target_joint[-1], block=False, timeout=2)
|
||||
|
||||
def write_endpose(self, target_endpose: list, gripper: int):
|
||||
self.rmarm.rm_movej_p(target_endpose, 50, 0, 0, 1)
|
||||
self.rmarm.rm_set_gripper_position(gripper, block=False, timeout=2)
|
||||
|
||||
def write_joint_slow(self, target_joint: list):
|
||||
self.rmarm.rm_movej(target_joint, 5, 0, 0, 0)
|
||||
|
||||
def write_joint_canfd(self, target_joint: list):
|
||||
self.rmarm.rm_movej_canfd(target_joint, False)
|
||||
|
||||
def write_endpose_canfd(self, target_pose: list):
|
||||
self.rmarm.rm_movep_canfd(target_pose, False)
|
||||
|
||||
def write_gripper(self, gripper: int):
|
||||
self.rmarm.rm_set_gripper_position(gripper, False, 2)
|
||||
|
||||
def read(self) -> Dict:
|
||||
"""
|
||||
- arm joint state, reported in degrees and normalized to [-1, 1]
- gripper state, normalized to [-1, 1]
|
||||
"""
|
||||
joint_msg = self.rmarm.rm_get_current_arm_state()[1]
|
||||
joint_state = joint_msg['joint']
|
||||
|
||||
gripper_msg = self.rmarm.rm_get_gripper_state()[1]
|
||||
gripper_state = gripper_msg['actpos']
|
||||
|
||||
return {
|
||||
"joint_1": joint_state[0]/180,
|
||||
"joint_2": joint_state[1]/180,
|
||||
"joint_3": joint_state[2]/180,
|
||||
"joint_4": joint_state[3]/180,
|
||||
"joint_5": joint_state[4]/180,
|
||||
"joint_6": joint_state[5]/180,
|
||||
"gripper": (gripper_state-500)/500
|
||||
}
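Note the conventions used here: the SDK reports joints in degrees, scaled by 1/180 into [-1, 1], and the gripper position (roughly 0-1000 counts) is mapped through (pos - 500) / 500. send_action further down applies the inverse mapping before writing to the arm; a small sanity check of that round trip under the same assumptions:

def denormalize(action):
    # inverse of read(): degrees for the joints, raw counts for the gripper
    joints = [a * 180 for a in action[:-1]]
    gripper = int(action[-1] * 500 + 500)
    return joints + [gripper]

assert denormalize([0.5, -0.5, 0.0, 0.25, -0.25, 1.0, 0.0]) == [90.0, -90.0, 0.0, 45.0, -45.0, 180.0, 500]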
|
||||
|
||||
def read_current_arm_joint_state(self):
|
||||
return self.rmarm.rm_get_current_arm_state()[1]['joint']
|
||||
|
||||
def read_current_arm_endpose_state(self):
|
||||
return self.rmarm.rm_get_current_arm_state()[1]['pose']
|
||||
|
||||
def safe_disconnect(self):
|
||||
"""
|
||||
Move to safe disconnect position
|
||||
"""
|
||||
self.write(target_joint=self.safe_disable_position)
|
||||
# drop all connections and destroy the SDK threads
|
||||
RoboticArm.rm_destory()
|
||||
@@ -44,6 +44,11 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "realman":
|
||||
from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus
|
||||
|
||||
motors_buses[key] = RealmanMotorsBus(cfg)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
@@ -65,3 +70,7 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
|
||||
else:
|
||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||
|
||||
|
||||
def get_motor_names(arm: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
||||
@@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
RealmanMotorsBusConfig
|
||||
)
|
||||
|
||||
|
||||
@@ -674,3 +675,91 @@ class LeKiwiRobotConfig(RobotConfig):
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("realman")
|
||||
@dataclass
|
||||
class RealmanRobotConfig(RobotConfig):
|
||||
inference_time: bool = False
|
||||
max_gripper: int = 990
|
||||
min_gripper: int = 10
|
||||
servo_config_file: str = "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
|
||||
|
||||
|
||||
left_follower_arm: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": RealmanMotorsBusConfig(
|
||||
ip = "192.168.3.18",
|
||||
port = 8080,
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"joint_1": [1, "realman"],
|
||||
"joint_2": [2, "realman"],
|
||||
"joint_3": [3, "realman"],
|
||||
"joint_4": [4, "realman"],
|
||||
"joint_5": [5, "realman"],
|
||||
"joint_6": [6, "realman"],
|
||||
"gripper": [7, "realman"],
|
||||
},
|
||||
init_joint = {'joint': [-90, 90, 90, 90, 90, -90, 10]}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
# "one": OpenCVCameraConfig(
|
||||
# camera_index=4,
|
||||
# fps=30,
|
||||
# width=640,
|
||||
# height=480,
|
||||
# ),
|
||||
"left": IntelRealSenseCameraConfig(
|
||||
serial_number="153122077516",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
use_depth=False
|
||||
),
|
||||
# "right": IntelRealSenseCameraConfig(
|
||||
# serial_number="405622075165",
|
||||
# fps=30,
|
||||
# width=640,
|
||||
# height=480,
|
||||
# use_depth=False
|
||||
# ),
|
||||
"front": IntelRealSenseCameraConfig(
|
||||
serial_number="145422072751",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
use_depth=False
|
||||
),
|
||||
"high": IntelRealSenseCameraConfig(
|
||||
serial_number="145422072193",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
use_depth=False
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# right_follower_arm: dict[str, MotorsBusConfig] = field(
|
||||
# default_factory=lambda: {
|
||||
# "main": RealmanMotorsBusConfig(
|
||||
# ip = "192.168.3.19",
|
||||
# port = 8080,
|
||||
# motors={
|
||||
# # name: (index, model)
|
||||
# "joint_1": [1, "realman"],
|
||||
# "joint_2": [2, "realman"],
|
||||
# "joint_3": [3, "realman"],
|
||||
# "joint_4": [4, "realman"],
|
||||
# "joint_5": [5, "realman"],
|
||||
# "joint_6": [6, "realman"],
|
||||
# "gripper": (7, "realman"),
|
||||
# },
|
||||
# )
|
||||
# }
|
||||
# )
|
||||
|
||||
lerobot/common/robot_devices/robots/realman.py (new file, 292 lines)
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
Teleoperation Realman with a PS5 controller and
|
||||
"""
|
||||
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field, replace
|
||||
from collections import deque
|
||||
from lerobot.common.robot_devices.teleop.gamepad import HybridController
|
||||
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig
|
||||
|
||||
|
||||
class RealmanRobot:
|
||||
def __init__(self, config: RealmanRobotConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = RealmanRobotConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
self.robot_type = self.config.type
|
||||
self.inference_time = self.config.inference_time # if it is inference time
|
||||
|
||||
# build cameras
|
||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||
|
||||
# build realman motors
|
||||
self.piper_motors = make_motors_buses_from_configs(self.config.left_follower_arm)
|
||||
self.arm = self.piper_motors['main']
|
||||
|
||||
# build init teleop info
|
||||
self.init_info = {
|
||||
'init_joint': self.arm.init_joint_position,
|
||||
'init_pose': self.arm.init_pose,
|
||||
'max_gripper': config.max_gripper,
|
||||
'min_gripper': config.min_gripper,
|
||||
'servo_config_file': config.servo_config_file
|
||||
}
|
||||
|
||||
# build state-action cache
|
||||
self.joint_queue = deque(maxlen=2)
|
||||
self.last_endpose = self.arm.init_pose
|
||||
|
||||
# build gamepad teleop
|
||||
if not self.inference_time:
|
||||
self.teleop = HybridController(self.init_info)
|
||||
else:
|
||||
self.teleop = None
|
||||
|
||||
self.logs = {}
|
||||
self.is_connected = False
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
action_names = get_motor_names(self.piper_motors)
|
||||
state_names = get_motor_names(self.piper_motors)
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action_names),),
|
||||
"names": action_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state_names),),
|
||||
"names": state_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
|
||||
@property
|
||||
def num_cameras(self):
|
||||
return len(self.cameras)
|
||||
|
||||
|
||||
def connect(self) -> None:
|
||||
"""Connect RealmanArm and cameras"""
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
"RealmanArm is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
# connect RealmanArm
|
||||
self.arm.connect(enable=True)
|
||||
print("RealmanArm conneted")
|
||||
|
||||
# connect cameras
|
||||
for name in self.cameras:
self.cameras[name].connect()
print(f"camera {name} connected")

print("All connected")
# consider the robot connected only once every camera reports a live connection
self.is_connected = all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
self.run_calibration()
|
||||
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""move to home position, disenable piper and cameras"""
|
||||
# move the arm to the home position, then disable it
|
||||
if not self.inference_time:
|
||||
self.teleop.stop()
|
||||
|
||||
# disconnect the arm
|
||||
self.arm.safe_disconnect()
|
||||
print("RealmanArm disable after 5 seconds")
|
||||
time.sleep(5)
|
||||
self.arm.connect(enable=False)
|
||||
|
||||
# disconnect cameras
|
||||
if len(self.cameras) > 0:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
|
||||
def run_calibration(self):
|
||||
"""move piper to the home position"""
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
self.arm.apply_calibration()
|
||||
if not self.inference_time:
|
||||
self.teleop.reset()
|
||||
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
if self.teleop is None and self.inference_time:
|
||||
self.teleop = HybridController(self.init_info)
|
||||
|
||||
# read the current robot state and the teleop target action
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.arm.read() # read current joint position from robot
|
||||
action = self.teleop.get_action() # target joint position and pose end pos from gamepad
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if action['control_mode'] == 'joint':
|
||||
# joint control mode (primary mode)
|
||||
current_pose = self.arm.read_current_arm_endpose_state()
|
||||
self.teleop.update_endpose_state(current_pose)
|
||||
|
||||
target_joints = action['joint_angles'][:-1]
|
||||
self.arm.write_gripper(action['gripper'])
|
||||
print(action['gripper'])
|
||||
if action['master_controller_status']['infrared'] == 1:
|
||||
if action['master_controller_status']['button'] == 1:
|
||||
self.arm.write_joint_canfd(target_joints)
|
||||
else:
|
||||
self.arm.write_joint_slow(target_joints)
|
||||
|
||||
# do action
|
||||
before_write_t = time.perf_counter()
|
||||
self.joint_queue.append(list(self.arm.read().values()))
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
|
||||
else:
|
||||
target_pose = list(action['end_pose'])
|
||||
# do action
|
||||
before_write_t = time.perf_counter()
|
||||
if self.last_endpose != target_pose:
|
||||
self.arm.write_endpose_canfd(target_pose)
|
||||
self.last_endpose = target_pose
|
||||
self.arm.write_gripper(action['gripper'])
|
||||
|
||||
target_joints = self.arm.read_current_arm_joint_state()
|
||||
self.joint_queue.append(list(self.arm.read().values()))
|
||||
self.teleop.update_joint_state(target_joints)
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
|
||||
action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = state
|
||||
action_dict["action"] = action
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""Write the predicted actions from policy to the motors"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"Piper is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# send to motors, torch to list
|
||||
target_joints = action.tolist()
|
||||
len_joint = len(target_joints) - 1
|
||||
target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]
|
||||
target_joints[-1] = int(target_joints[-1]*500 + 500)
|
||||
self.arm.write(target_joints)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
"""capture current images and joint positions"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
"Piper is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
# read current joint positions
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.arm.read() # 6 joints + 1 gripper
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
state = torch.as_tensor(list(state.values()), dtype=torch.float32)
|
||||
|
||||
# read images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionaries and format to pytorch
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
return obs_dict
|
||||
|
||||
def teleop_safety_stop(self):
|
||||
""" move to home position after record one episode """
|
||||
self.run_calibration()
|
||||
|
||||
|
||||
def __del__(self):
|
||||
if self.is_connected:
|
||||
self.disconnect()
|
||||
if not self.inference_time:
|
||||
self.teleop.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
robot = RealmanRobot()
|
||||
robot.connect()
|
||||
# robot.run_calibration()
|
||||
while True:
|
||||
robot.teleop_step(record_data=True)
|
||||
|
||||
robot.capture_observation()
|
||||
dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
|
||||
robot.send_action(dummy_action)
|
||||
robot.disconnect()
|
||||
print('ok')
|
||||
@@ -25,6 +25,7 @@ from lerobot.common.robot_devices.robots.configs import (
|
||||
So100RobotConfig,
|
||||
So101RobotConfig,
|
||||
StretchRobotConfig,
|
||||
RealmanRobotConfig
|
||||
)
|
||||
|
||||
|
||||
@@ -65,6 +66,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
return StretchRobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
return LeKiwiRobotConfig(**kwargs)
|
||||
elif robot_type == 'realman':
|
||||
return RealmanRobotConfig(**kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
@@ -78,6 +82,12 @@ def make_robot_from_config(config: RobotConfig):
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
||||
|
||||
return MobileManipulator(config)
|
||||
|
||||
elif isinstance(config, RealmanRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.realman import RealmanRobot
|
||||
|
||||
return RealmanRobot(config)
|
||||
|
||||
else:
|
||||
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
||||
|
||||
|
||||
lerobot/common/robot_devices/teleop/gamepad.py (new file, 466 lines)
@@ -0,0 +1,466 @@
|
||||
import pygame
|
||||
import threading
|
||||
import time
|
||||
import serial
|
||||
import binascii
|
||||
import logging
|
||||
import yaml
|
||||
from typing import Dict
|
||||
from Robotic_Arm.rm_robot_interface import *
|
||||
|
||||
|
||||
|
||||
class ServoArm:
|
||||
def __init__(self, config_file="config.yaml"):
|
||||
"""初始化机械臂的串口连接并发送初始数据。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
"""
|
||||
self.config = self._load_config(config_file)
|
||||
self.port = self.config["port"]
|
||||
self.baudrate = self.config["baudrate"]
|
||||
self.joint_hex_data = self.config["joint_hex_data"]
|
||||
self.control_hex_data = self.config["control_hex_data"]
|
||||
self.arm_axis = self.config.get("arm_axis", 7)
|
||||
|
||||
try:
|
||||
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
|
||||
self.bytes_to_send = binascii.unhexlify(self.joint_hex_data.replace(" ", ""))
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
time.sleep(1)
|
||||
self.connected = True
|
||||
logging.info(f"串口连接成功: {self.port}")
|
||||
except Exception as e:
|
||||
logging.error(f"串口连接失败: {e}")
|
||||
self.connected = False
|
||||
|
||||
def _load_config(self, config_file):
|
||||
"""加载配置文件。
|
||||
|
||||
Args:
|
||||
config_file (str): 配置文件的路径。
|
||||
|
||||
Returns:
|
||||
dict: 配置文件内容。
|
||||
"""
|
||||
try:
|
||||
with open(config_file, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
return config
|
||||
except Exception as e:
|
||||
logging.error(f"配置文件加载失败: {e}")
|
||||
# fall back to the default configuration
|
||||
return {
|
||||
"port": "/dev/ttyUSB0",
|
||||
"baudrate": 460800,
|
||||
"joint_hex_data": "55 AA 02 00 00 67",
|
||||
"control_hex_data": "55 AA 08 00 00 B9",
|
||||
"arm_axis": 6
|
||||
}
|
||||
|
||||
def _bytes_to_signed_int(self, byte_data):
|
||||
"""将字节数据转换为有符号整数。
|
||||
|
||||
Args:
|
||||
byte_data (bytes): 字节数据。
|
||||
|
||||
Returns:
|
||||
int: 有符号整数。
|
||||
"""
|
||||
return int.from_bytes(byte_data, byteorder="little", signed=True)
|
||||
|
||||
def _parse_joint_data(self, hex_received):
|
||||
"""解析接收到的十六进制数据并提取关节数据。
|
||||
|
||||
Args:
|
||||
hex_received (str): 接收到的十六进制字符串数据。
|
||||
|
||||
Returns:
|
||||
dict: 解析后的关节数据。
|
||||
"""
|
||||
logging.debug(f"hex_received: {hex_received}")
|
||||
joints = {}
|
||||
for i in range(self.arm_axis):
|
||||
start = 14 + i * 10
|
||||
end = start + 8
|
||||
joint_hex = hex_received[start:end]
|
||||
joint_byte_data = bytearray.fromhex(joint_hex)
|
||||
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
|
||||
joints[f"joint_{i+1}"] = joint_value
|
||||
grasp_start = 14 + self.arm_axis*10
|
||||
grasp_hex = hex_received[grasp_start:grasp_start+8]
|
||||
grasp_byte_data = bytearray.fromhex(grasp_hex)
|
||||
# normalize the gripper reading
|
||||
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
|
||||
|
||||
joints["grasp"] = grasp_value
|
||||
return joints
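Each joint field in the reply is 8 hex characters (4 bytes) starting at offset 14 + i*10, decoded as a little-endian signed integer and scaled by 1/10000. A worked decode of one hypothetical field:

joint_hex = "A0BB0D00"   # hypothetical field: bytes A0 BB 0D 00, little-endian
raw = int.from_bytes(bytearray.fromhex(joint_hex), byteorder="little", signed=True)
assert raw == 900000 and raw / 10000.0 == 90.0   # i.e. 90 degrees for this joint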
|
||||
|
||||
def _parse_controller_data(self, hex_received):
|
||||
status = {
|
||||
'infrared': 0,
|
||||
'button': 0
|
||||
}
|
||||
if len(hex_received) == 18:
|
||||
status['infrared'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[12:14]))
|
||||
status['button'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[14:16]))
|
||||
# print(infrared)
|
||||
return status
|
||||
|
||||
def get_joint_actions(self):
|
||||
"""从串口读取数据并解析关节动作。
|
||||
|
||||
Returns:
|
||||
dict: 包含关节数据的字典。
|
||||
"""
|
||||
if not self.connected:
|
||||
return {}
|
||||
|
||||
try:
|
||||
self.serial_conn.write(self.bytes_to_send)
|
||||
time.sleep(0.02)
|
||||
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||
if len(bytes_received) == 0:
|
||||
return {}
|
||||
|
||||
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||
actions = self._parse_joint_data(hex_received)
|
||||
return actions
|
||||
except Exception as e:
|
||||
logging.error(f"读取串口数据错误: {e}")
|
||||
return {}
|
||||
|
||||
def get_controller_status(self):
|
||||
bytes_to_send = binascii.unhexlify(self.control_hex_data.replace(" ", ""))
|
||||
self.serial_conn.write(bytes_to_send)
|
||||
time.sleep(0.02)
|
||||
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||
# print("control status:", hex_received)
|
||||
status = self._parse_controller_data(hex_received)
|
||||
return status
|
||||
|
||||
def close(self):
|
||||
"""关闭串口连接"""
|
||||
if self.connected and hasattr(self, 'serial_conn'):
|
||||
self.serial_conn.close()
|
||||
self.connected = False
|
||||
logging.info("串口连接已关闭")
|
||||
|
||||
|
||||
class HybridController:
|
||||
def __init__(self, init_info):
|
||||
# 初始化pygame和手柄
|
||||
pygame.init()
|
||||
pygame.joystick.init()
|
||||
|
||||
# 检查是否有连接的手柄
|
||||
if pygame.joystick.get_count() == 0:
|
||||
raise Exception("未检测到手柄")
|
||||
|
||||
# 初始化手柄
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
# 摇杆死区
|
||||
self.deadzone = 0.15
|
||||
# 控制模式: True为关节控制(主模式),False为末端控制
|
||||
self.joint_control_mode = True
|
||||
# 精细控制模式
|
||||
self.fine_control_mode = False
|
||||
|
||||
# 初始化末端姿态和关节角度
|
||||
self.init_joint = init_info['init_joint']
|
||||
self.init_pose = init_info.get('init_pose', [0]*6)
|
||||
self.max_gripper = init_info['max_gripper']
|
||||
self.min_gripper = init_info['min_gripper']
|
||||
servo_config_file = init_info['servo_config_file']
|
||||
self.joint = self.init_joint.copy()
|
||||
self.pose = self.init_pose.copy()
|
||||
self.pose_speeds = [0.0] * 6
|
||||
self.joint_speeds = [0.0] * 6
|
||||
self.tozero = False
|
||||
|
||||
# 主臂关节状态
|
||||
self.master_joint_actions = {}
|
||||
self.master_controller_status = {}
|
||||
self.use_master_arm = False
|
||||
|
||||
# 末端位姿限制
|
||||
self.pose_limits = [
|
||||
(-0.800, 0.800), # X (m)
|
||||
(-0.800, 0.800), # Y (m)
|
||||
(-0.800, 0.800), # Z (m)
|
||||
(-3.14, 3.14), # RX (rad)
|
||||
(-3.14, 3.14), # RY (rad)
|
||||
(-3.14, 3.14) # RZ (rad)
|
||||
]
|
||||
|
||||
# 关节角度限制 (度)
|
||||
self.joint_limits = [
|
||||
(-180, 180), # joint 1
|
||||
(-180, 180), # joint 2
|
||||
(-180, 180), # joint 3
|
||||
(-180, 180), # joint 4
|
||||
(-180, 180), # joint 5
|
||||
(-180, 180) # joint 6
|
||||
]
|
||||
|
||||
# 控制参数
|
||||
self.linear_step = 0.002 # 线性移动步长(m)
|
||||
self.angular_step = 0.01 # 角度步长(rad)
|
||||
|
||||
# 夹爪状态和速度
|
||||
self.gripper_speed = 10
|
||||
self.gripper = self.min_gripper
|
||||
|
||||
# 初始化串口通信(主臂关节状态获取)
|
||||
self.servo_arm = None
|
||||
if servo_config_file:
|
||||
try:
|
||||
self.servo_arm = ServoArm(servo_config_file)
|
||||
self.use_master_arm = True
|
||||
logging.info("串口主臂连接成功,启用主从控制模式")
|
||||
except Exception as e:
|
||||
logging.error(f"串口主臂连接失败: {e}")
|
||||
self.use_master_arm = False
|
||||
|
||||
# 启动更新线程
|
||||
self.running = True
|
||||
self.thread = threading.Thread(target=self.update_controller)
|
||||
self.thread.start()
|
||||
|
||||
print("混合控制器已启动")
|
||||
print("主控制模式: 关节控制")
|
||||
if self.use_master_arm:
|
||||
print("主从控制: 启用")
|
||||
print("Back按钮: 切换控制模式(关节/末端)")
|
||||
print("L3按钮: 切换精细控制模式")
|
||||
print("Start按钮: 重置到初始位置")
|
||||
|
||||
def _apply_nonlinear_mapping(self, value):
|
||||
"""应用非线性映射以提高控制精度"""
|
||||
sign = 1 if value >= 0 else -1
|
||||
return sign * (abs(value) ** 2)
|
||||
|
||||
def _normalize_angle(self, angle):
|
||||
"""将角度归一化到[-π, π]范围内"""
|
||||
import math
|
||||
while angle > math.pi:
|
||||
angle -= 2 * math.pi
|
||||
while angle < -math.pi:
|
||||
angle += 2 * math.pi
|
||||
return angle
|
||||
|
||||
def update_controller(self):
|
||||
while self.running:
|
||||
try:
|
||||
pygame.event.pump()
|
||||
except Exception as e:
|
||||
print(f"控制器错误: {e}")
|
||||
self.stop()
|
||||
continue
|
||||
|
||||
# check for a control-mode toggle (Back button)
if self.joystick.get_button(6): # Back button
self.joint_control_mode = not self.joint_control_mode
mode_str = "joint control" if self.joint_control_mode else "end-effector pose control"
print(f"switched to {mode_str} mode")
time.sleep(0.3) # debounce repeated presses

# check for a fine-control toggle (L3 button)
if self.joystick.get_button(10): # L3 button
self.fine_control_mode = not self.fine_control_mode
print(f"switched to {'fine' if self.fine_control_mode else 'normal'} control mode")
time.sleep(0.3) # debounce repeated presses

# check for a reset (Start button)
if self.joystick.get_button(7): # Start button
print("resetting the arm to its initial position ...")
# print("init_joint", self.init_joint.copy())
self.tozero = True
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.gripper_speed = 10
self.gripper = self.min_gripper
print("arm reset to its initial position")
time.sleep(0.3) # debounce repeated presses
|
||||
|
||||
# 从串口获取主臂关节状态
|
||||
if self.servo_arm and self.servo_arm.connected:
|
||||
try:
|
||||
self.master_joint_actions = self.servo_arm.get_joint_actions()
|
||||
self.master_controller_status = self.servo_arm.get_controller_status()
|
||||
if self.master_joint_actions:
|
||||
logging.debug(f"主臂关节状态: {self.master_joint_actions}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"获取主臂状态错误: {e}")
|
||||
self.master_joint_actions = {}
|
||||
# print(self.master_joint_actions)
|
||||
|
||||
# 根据控制模式更新相应的控制逻辑
|
||||
if self.joint_control_mode:
|
||||
# 关节控制模式下,优先使用主臂数据,Xbox作为辅助
|
||||
self.update_joint_control()
|
||||
else:
|
||||
# 末端控制模式,使用Xbox控制
|
||||
self.update_end_pose()
|
||||
time.sleep(0.02)
|
||||
# print('gripper:', self.gripper)
|
||||
|
||||
def update_joint_control(self):
|
||||
"""更新关节角度控制 - 优先使用主臂数据"""
|
||||
if self.use_master_arm and self.master_joint_actions:
|
||||
# 主从控制模式:直接使用主臂的关节角度
|
||||
try:
|
||||
# 将主臂关节角度映射到从臂
|
||||
for i in range(6): # 假设只有6个关节需要控制
|
||||
joint_key = f"joint_{i+1}"
|
||||
if joint_key in self.master_joint_actions:
|
||||
# 直接使用主臂的关节角度(已经是度数)
|
||||
self.joint[i] = self.master_joint_actions[joint_key]
|
||||
|
||||
# 应用关节限制
|
||||
min_val, max_val = self.joint_limits[i]
|
||||
self.joint[i] = max(min_val, min(max_val, self.joint[i]))
|
||||
|
||||
# print(self.joint)
|
||||
logging.debug(f"主臂关节映射到从臂: {self.joint[:6]}")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"主臂数据映射错误: {e}")
|
||||
|
||||
# 如果有主臂夹爪数据,使用主臂夹爪状态
|
||||
if self.use_master_arm and "grasp" in self.master_joint_actions:
|
||||
self.gripper = self.master_joint_actions["grasp"] * 1000
|
||||
self.joint[-1] = self.gripper
|
||||
|
||||
|
||||
def update_end_pose(self):
|
||||
"""更新末端位姿控制"""
|
||||
# 根据控制模式调整步长
|
||||
current_linear_step = self.linear_step * (0.1 if self.fine_control_mode else 1.0)
|
||||
current_angular_step = self.angular_step * (0.1 if self.fine_control_mode else 1.0)
|
||||
|
||||
# 方向键控制XY
|
||||
hat = self.joystick.get_hat(0)
|
||||
hat_up = hat[1] == 1 # Y+
|
||||
hat_down = hat[1] == -1 # Y-
|
||||
hat_left = hat[0] == -1 # X-
|
||||
hat_right = hat[0] == 1 # X+
|
||||
|
||||
# 右摇杆控制Z
|
||||
right_y_raw = -self.joystick.get_axis(4)
|
||||
# 左摇杆控制RZ
|
||||
left_y_raw = -self.joystick.get_axis(1)
|
||||
|
||||
# 应用死区
|
||||
right_y = 0.0 if abs(right_y_raw) < self.deadzone else right_y_raw
|
||||
left_y = 0.0 if abs(left_y_raw) < self.deadzone else left_y_raw
|
||||
|
||||
# 计算各轴速度
|
||||
self.pose_speeds[1] = current_linear_step if hat_up else (-current_linear_step if hat_down else 0.0) # Y
|
||||
self.pose_speeds[0] = -current_linear_step if hat_left else (current_linear_step if hat_right else 0.0) # X
|
||||
|
||||
# 设置Z速度(右摇杆Y轴控制)
|
||||
z_mapping = self._apply_nonlinear_mapping(right_y)
|
||||
self.pose_speeds[2] = z_mapping * current_linear_step # Z
|
||||
|
||||
# L1/R1控制RX旋转
|
||||
LB = self.joystick.get_button(4) # RX-
|
||||
RB = self.joystick.get_button(5) # RX+
|
||||
self.pose_speeds[3] = (-current_angular_step if LB else (current_angular_step if RB else 0.0))
|
||||
|
||||
# △/□控制RY旋转
|
||||
triangle = self.joystick.get_button(2) # RY+
|
||||
square = self.joystick.get_button(3) # RY-
|
||||
self.pose_speeds[4] = (current_angular_step if triangle else (-current_angular_step if square else 0.0))
|
||||
|
||||
# 左摇杆Y轴控制RZ旋转
|
||||
rz_mapping = self._apply_nonlinear_mapping(left_y)
|
||||
self.pose_speeds[5] = rz_mapping * current_angular_step * 2 # RZ
|
||||
|
||||
# 夹爪控制(圈/叉)
|
||||
circle = self.joystick.get_button(1) # 夹爪开
|
||||
cross = self.joystick.get_button(0) # 夹爪关
|
||||
if circle:
|
||||
self.gripper = min(self.max_gripper, self.gripper + self.gripper_speed)
|
||||
elif cross:
|
||||
self.gripper = max(self.min_gripper, self.gripper - self.gripper_speed)
|
||||
|
||||
# 更新末端位姿
|
||||
for i in range(6):
|
||||
self.pose[i] += self.pose_speeds[i]
|
||||
|
||||
# 角度归一化处理
|
||||
for i in range(3, 6):
|
||||
self.pose[i] = self._normalize_angle(self.pose[i])
|
||||
|
||||
def update_joint_state(self, joint):
|
||||
self.joint = joint
|
||||
# self.tozero = False
|
||||
|
||||
def update_endpose_state(self, end_pose):
|
||||
self.pose = end_pose
|
||||
# self.tozero = False
|
||||
|
||||
def update_tozero_state(self, tozero):
|
||||
self.tozero = tozero
|
||||
|
||||
|
||||
def get_action(self) -> Dict:
|
||||
"""获取当前控制命令"""
|
||||
return {
|
||||
'control_mode': 'joint' if self.joint_control_mode else 'end_pose',
|
||||
'use_master_arm': self.use_master_arm,
|
||||
'master_joint_actions': self.master_joint_actions,
|
||||
'master_controller_status': self.master_controller_status,
|
||||
'end_pose': self.pose,
|
||||
'joint_angles': self.joint,
|
||||
'gripper': int(self.gripper),
|
||||
'tozero': self.tozero
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
"""停止控制器"""
|
||||
self.running = False
|
||||
if self.thread.is_alive():
|
||||
self.thread.join()
|
||||
if self.servo_arm:
|
||||
self.servo_arm.close()
|
||||
pygame.quit()
|
||||
print("混合控制器已退出")
|
||||
|
||||
def reset(self):
|
||||
"""重置到初始状态"""
|
||||
self.joint = self.init_joint.copy()
|
||||
self.pose = self.init_pose.copy()
|
||||
self.pose_speeds = [0.0] * 6
|
||||
self.joint_speeds = [0.0] * 6
|
||||
self.gripper_speed = 10
|
||||
self.gripper = self.min_gripper
|
||||
print("已重置到初始状态")
|
||||
|
||||
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
# 初始化睿尔曼机械臂
|
||||
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||
# 创建机械臂连接
|
||||
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
|
||||
print(f"机械臂连接ID: {handle.id}")
|
||||
init_pose = arm.rm_get_current_arm_state()[1]['pose']
|
||||
|
||||
with open('/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/realman_mix.yaml', "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
config['init_pose'] = init_pose
|
||||
arm_controller = HybridController(config)
|
||||
try:
|
||||
while True:
|
||||
print(arm_controller.get_action())
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
arm_controller.stop()
|
||||
lerobot/common/robot_devices/teleop/realman_mix.yaml (new file, 4 lines)
@@ -0,0 +1,4 @@
init_joint: [-90, 90, 90, -90, -90, 90]
max_gripper: 990
min_gripper: 10
servo_config_file: "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
lerobot/common/robot_devices/teleop/servo_arm.yaml (new file, 6 lines)
@@ -0,0 +1,6 @@
port: /dev/ttyUSB0
right_port: /dev/ttyUSB1
baudrate: 460800
joint_hex_data: "55 AA 02 00 00 67"
control_hex_data: "55 AA 08 00 00 B9"
arm_axis: 6
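Note: the two serial devices referenced above must exist and be readable by your user before the master-arm reader can open them. A quick sanity check on Linux (device names are the ones from this file and may differ on your machine):

```bash
ls -l /dev/ttyUSB0 /dev/ttyUSB1
# if opening the port fails with "Permission denied", add your user to the dialout group and log in again
sudo usermod -aG dialout $USER
```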
@@ -175,7 +175,8 @@ def say(text, blocking=False):
        cmd = ["say", text]

    elif system == "Linux":
        cmd = ["spd-say", text]
        # cmd = ["spd-say", text]
        cmd = ["edge-playback", "-t", text]
        if blocking:
            cmd.append("--wait")

@@ -273,7 +273,6 @@ def record(

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    if not robot.is_connected:
        robot.connect()

@@ -290,6 +289,9 @@
        if has_method(robot, "teleop_safety_stop"):
            robot.teleop_safety_stop()

        # import pdb
        # pdb.set_trace()

    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:
@@ -63,13 +63,13 @@ dependencies = [
    "opencv-python-headless>=4.9.0",
    "packaging>=24.2",
    "av>=14.2.0",
    "pymunk>=6.6.0",
    "pymunk>=6.6.0,<7.0.0",
    "pynput>=1.7.7",
    "pyzmq>=26.2.1",
    "rerun-sdk>=0.21.0",
    "termcolor>=2.4.0",
    "torch>=2.2.1,<2.7",
    "torchcodec==0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
    "torch>=2.2.1",
    "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
    "torchvision>=0.21.0",
    "wandb>=0.16.3",
    "zarr>=2.17.0",
@@ -86,6 +86,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"]
feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"]
intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"]
pi0 = ["transformers>=4.48.0"]
smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"]
pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"]
stretch = [
    "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'",
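To confirm which versions pip actually resolved after installing the project (package names taken from the dependency list above; the output format is pip's standard metadata):

```bash
pip show torch torchcodec torchvision pymunk | grep -E '^(Name|Version):'
```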
realman.md (new file, 156 lines)
@@ -0,0 +1,156 @@
# Install
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```

Install 🤗 LeRobot:
```bash
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install edge-tts
sudo apt install mpv -y

# pip uninstall numpy
# pip install numpy==1.26.0
# pip install pynput
```

/!\ For Linux only, ffmpeg and opencv currently require a conda install. Run this exact sequence of commands:
```bash
conda install ffmpeg=7.1.1 -c conda-forge
# pip uninstall opencv-python
# conda install "opencv>=4.10.0"
```

Install the Realman SDK:
```bash
pip install Robotic_Arm==1.0.4.1
pip install pygame
```

# Integrating piper into LeRobot
See lerobot_piper_tutorial/1. 🤗 LeRobot:新增机械臂的一般流程.pdf

# Teleoperate
```bash
cd piper_scripts/
bash can_activate.sh can0 1000000

cd ..
python lerobot/scripts/control_robot.py \
    --robot.type=piper \
    --robot.inference_time=false \
    --control.type=teleoperate
```

# Record
Set the dataset root path:
```bash
HF_USER=$PWD/data
echo $HF_USER
```

```bash
python lerobot/scripts/control_robot.py \
    --robot.type=realman \
    --robot.inference_time=false \
    --control.type=record \
    --control.fps=15 \
    --control.single_task="move" \
    --control.repo_id=maic/test \
    --control.num_episodes=2 \
    --control.warmup_time_s=2 \
    --control.episode_time_s=10 \
    --control.reset_time_s=10 \
    --control.play_sounds=true \
    --control.push_to_hub=false \
    --control.display_data=true
```

Press the right arrow -> at any time during episode recording to stop early and go to resetting; during resetting, press it to stop early and go to the next episode recording.
Press the left arrow <- at any time during episode recording or resetting to stop early, cancel the current episode, and re-record it.
Press escape ESC at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.

# Visualize
```bash
python lerobot/scripts/visualize_dataset.py \
    --repo-id ${HF_USER}/test \
    --episode-index 0
```

# Replay
```bash
python lerobot/scripts/control_robot.py \
    --robot.type=piper \
    --robot.inference_time=false \
    --control.type=replay \
    --control.fps=30 \
    --control.repo_id=${HF_USER}/test \
    --control.episode=0
```

# Caution

1. In lerobot/common/datasets/video_utils, the vcodec is set to **libopenh264**; check which vcodecs your machine provides with **ffmpeg -codecs**.

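For example, to see which H.264 encoders your local ffmpeg build actually ships (the encoder names that show up, such as libopenh264 or libx264, depend on your build):

```bash
# list codec entries related to H.264
ffmpeg -codecs | grep -i h264
# or list encoders only
ffmpeg -encoders | grep -i 264
```
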
# Train
For the detailed training workflow, see lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf
```bash
python lerobot/scripts/train.py \
    --dataset.repo_id=${HF_USER}/jack \
    --policy.type=act \
    --output_dir=outputs/train/act_jack \
    --job_name=act_jack \
    --device=cuda \
    --wandb.enable=true
```

# Fine-tune smolvla
```bash
python lerobot/scripts/train.py \
    --dataset.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
    --policy.path=lerobot/smolvla_base \
    --output_dir=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
    --job_name=smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
    --policy.device=cuda \
    --wandb.enable=false \
    --steps=200000 \
    --batch_size=16
```

# Inference
Inference reuses the record loop in control_robot.py; set **--robot.inference_time=true** so the gamepad can be left out.
```bash
python lerobot/scripts/control_robot.py \
    --robot.type=realman \
    --robot.inference_time=true \
    --control.type=record \
    --control.fps=30 \
    --control.single_task="move the bottle into ultrasonic device with realman single" \
    --control.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
    --control.num_episodes=1 \
    --control.warmup_time_s=2 \
    --control.episode_time_s=30 \
    --control.reset_time_s=10 \
    --control.push_to_hub=false \
    --control.policy.path=outputs/train/act_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/100000/pretrained_model
```

```bash
python lerobot/scripts/control_robot.py \
    --robot.type=realman \
    --robot.inference_time=true \
    --control.type=record \
    --control.fps=30 \
    --control.single_task="move the bottle into ultrasonic device with realman single" \
    --control.repo_id=maic/eval_smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
    --control.num_episodes=1 \
    --control.warmup_time_s=2 \
    --control.episode_time_s=60 \
    --control.reset_time_s=10 \
    --control.push_to_hub=false \
    --control.policy.path=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/160000/pretrained_model \
    --control.display_data=true
```
realman_src/dual_arm_connect_test.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from Robotic_Arm.rm_robot_interface import *

armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
armright = RoboticArm()


lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
print("Arm ID:", lefthandle.id)
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
print("Arm ID:", righthandle.id)

# software_info = armleft.rm_get_arm_software_info()
# if software_info[0] == 0:
#     print("\n================== Arm Software Information ==================")
#     print("Arm Model: ", software_info[1]['product_version'])
#     print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
#     print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
#     print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
#     print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
#     print("==============================================================\n")
# else:
#     print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")

print("Left: ", armleft.rm_get_current_arm_state())
print("Left: ", armleft.rm_get_arm_all_state())
armleft.rm_movej_p()
# print("Right: ", armright.rm_get_current_arm_state())


# Disconnect all connections and destroy the threads
RoboticArm.rm_destory()
realman_src/movep_canfd.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from Robotic_Arm.rm_robot_interface import *
import time

# Instantiate the RoboticArm class
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# Create the arm connection and print the connection id
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(handle.id)

print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
time.sleep(2)
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
# time.sleep(2)

arm.rm_delete_robot_arm()
realman_src/realman_aloha/__init__.py (new file, 0 lines)

realman_src/realman_aloha/shadow_camera/.gitignore (new vendored file, 4 lines)
@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
*.pt

realman_src/realman_aloha/shadow_camera/README.md (new file, 0 lines)

realman_src/realman_aloha/shadow_camera/__init__.py (new file, 0 lines)

realman_src/realman_aloha/shadow_camera/pyproject.toml (new file, 33 lines)
@@ -0,0 +1,33 @@
[tool.poetry]
name = "shadow_camera"
version = "0.1.0"
description = "camera class, currently includes realsense"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
#include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=2.0.1"
opencv-python = ">=4.10.0.84"
pyrealsense2 = ">=2.55.1.6486"

[tool.poetry.dev-dependencies] # development-time dependencies, e.g. testing and documentation tools
pytest = ">=8.3"
black = ">=24.10.0"

[tool.poetry.plugins."scripts"] # command-line scripts that let users run a given function from the CLI


[tool.poetry.group.dev.dependencies]



[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"
@@ -0,0 +1 @@
__version__ = '0.1.0'
@@ -0,0 +1,38 @@
from abc import ABCMeta, abstractmethod


class BaseCamera(metaclass=ABCMeta):
    """Camera base class."""

    def __init__(self):
        pass

    @abstractmethod
    def start_camera(self):
        """Start the camera."""
        pass

    @abstractmethod
    def stop_camera(self):
        """Stop the camera."""
        pass

    @abstractmethod
    def set_resolution(self, resolution_width, resolution_height):
        """Set the camera's color image resolution."""
        pass

    @abstractmethod
    def set_frame_rate(self, fps):
        """Set the camera's color image frame rate."""
        pass

    @abstractmethod
    def read_frame(self):
        """Read one color frame and one depth frame."""
        pass

    @abstractmethod
    def get_camera_intrinsics(self):
        """Get the intrinsics of the color and depth images."""
        pass
Binary file not shown.
@@ -0,0 +1,38 @@
from shadow_camera import base_camera
import cv2


class OpenCVCamera(base_camera.BaseCamera):
    """OpenCV-based camera class."""

    def __init__(self, device_id=0):
        """Initialize video capture.

        Args:
            device_id: camera device ID
        """
        self.cap = cv2.VideoCapture(device_id)

    def get_frame(self):
        """Get the current frame.

        Returns:
            frame: image data of the current frame, or None if no frame could be read
        """
        ret, frame = self.cap.read()
        return frame if ret else None

    def get_frame_info(self):
        """Get information about the current frame.

        Returns:
            dict: frame info dictionary
        """
        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))

        return {
            'width': width,
            'height': height,
            'channels': channels
        }
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,280 @@
|
||||
import time
|
||||
import logging
|
||||
import numpy as np
|
||||
import pyrealsense2 as rs
|
||||
import base_camera
|
||||
|
||||
# 设置日志配置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class RealSenseCamera(base_camera.BaseCamera):
|
||||
"""Intel RealSense相机类"""
|
||||
|
||||
def __init__(self, serial_num=None, is_depth_frame=False):
|
||||
"""
|
||||
初始化相机对象
|
||||
:param serial_num: 相机序列号,默认为None
|
||||
"""
|
||||
super().__init__()
|
||||
self._color_resolution = [640, 480]
|
||||
self._depth_resolution = [640, 480]
|
||||
self._color_frames_rate = 30
|
||||
self._depth_frames_rate = 15
|
||||
self.timestamp = 0
|
||||
self.color_timestamp = 0
|
||||
self.depth_timestamp = 0
|
||||
self._colorizer = rs.colorizer()
|
||||
self._config = rs.config()
|
||||
self.is_depth_frame = is_depth_frame
|
||||
self.camera_on = False
|
||||
self.serial_num = serial_num
|
||||
|
||||
def get_serial_num(self):
|
||||
serial_num = {}
|
||||
context = rs.context()
|
||||
devices = context.query_devices() # 获取所有设备
|
||||
if len(context.devices) > 0:
|
||||
for i, device in enumerate(devices):
|
||||
serial_num[i] = device.get_info(rs.camera_info.serial_number)
|
||||
|
||||
logging.info(f"Detected serial numbers: {serial_num}")
|
||||
return serial_num
|
||||
|
||||
def _set_config(self):
|
||||
if self.serial_num is not None:
|
||||
logging.info(f"Setting device with serial number: {self.serial_num}")
|
||||
self._config.enable_device(self.serial_num)
|
||||
|
||||
self._config.enable_stream(
|
||||
rs.stream.color,
|
||||
self._color_resolution[0],
|
||||
self._color_resolution[1],
|
||||
rs.format.rgb8,
|
||||
self._color_frames_rate,
|
||||
)
|
||||
if self.is_depth_frame:
|
||||
self._config.enable_stream(
|
||||
rs.stream.depth,
|
||||
self._depth_resolution[0],
|
||||
self._depth_resolution[1],
|
||||
rs.format.z16,
|
||||
self._depth_frames_rate,
|
||||
)
|
||||
|
||||
def start_camera(self):
|
||||
"""
|
||||
启动相机并获取内参信息,如果后续调用帧对齐,则内参均为彩色内参
|
||||
"""
|
||||
self._pipeline = rs.pipeline()
|
||||
if self.is_depth_frame:
|
||||
self.point_cloud = rs.pointcloud()
|
||||
self._align = rs.align(rs.stream.color)
|
||||
self._set_config()
|
||||
|
||||
self.profile = self._pipeline.start(self._config)
|
||||
|
||||
if self.is_depth_frame:
|
||||
self._depth_intrinsics = (
|
||||
self.profile.get_stream(rs.stream.depth)
|
||||
.as_video_stream_profile()
|
||||
.get_intrinsics()
|
||||
)
|
||||
|
||||
self._color_intrinsics = (
|
||||
self.profile.get_stream(rs.stream.color)
|
||||
.as_video_stream_profile()
|
||||
.get_intrinsics()
|
||||
)
|
||||
self.camera_on = True
|
||||
logging.info("Camera started successfully")
|
||||
logging.info(
|
||||
f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
|
||||
)
|
||||
logging.info(
|
||||
f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
|
||||
)
|
||||
|
||||
def stop_camera(self):
|
||||
"""
|
||||
停止相机
|
||||
"""
|
||||
self._pipeline.stop()
|
||||
self.camera_on = False
|
||||
logging.info("Camera stopped")
|
||||
|
||||
def set_resolution(self, color_resolution, depth_resolution):
|
||||
self._color_resolution = color_resolution
|
||||
self._depth_resolution = depth_resolution
|
||||
logging.info(
|
||||
"Optional color resolution:"
|
||||
"[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
|
||||
"[848, 480] [960, 540] [1280, 720] [1920, 1080]"
|
||||
)
|
||||
logging.info(
|
||||
"Optional depth resolution:"
|
||||
"[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
|
||||
"[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
|
||||
)
|
||||
logging.info(f"Set color resolution to: {color_resolution}")
|
||||
logging.info(f"Set depth resolution to: {depth_resolution}")
|
||||
|
||||
def set_frame_rate(self, color_fps, depth_fps):
|
||||
self._color_frames_rate = color_fps
|
||||
self._depth_frames_rate = depth_fps
|
||||
logging.info("Optional color fps: 6 15 30 60 ")
|
||||
logging.info("Optional depth fps: 6 15 30 60 90 100 300")
|
||||
logging.info(f"Set color FPS to: {color_fps}")
|
||||
logging.info(f"Set depth FPS to: {depth_fps}")
|
||||
|
||||
# TODO: 调节白平衡进行补偿
|
||||
# def set_exposure(self, exposure):
|
||||
|
||||
def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||
"""
|
||||
读取一帧彩色图像和深度图像
|
||||
:return: 彩色图像和深度图像的NumPy数组
|
||||
"""
|
||||
while not self.camera_on:
|
||||
time.sleep(0.5)
|
||||
color_image = None
|
||||
depth_image = None
|
||||
colorized_depth = None
|
||||
point_cloud = None
|
||||
try:
|
||||
frames = self._pipeline.wait_for_frames()
|
||||
if is_color:
|
||||
color_frame = frames.get_color_frame()
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
else:
|
||||
color_image = None
|
||||
|
||||
if is_depth:
|
||||
depth_frame = frames.get_depth_frame()
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
else:
|
||||
depth_image = None
|
||||
|
||||
colorized_depth = (
|
||||
np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
|
||||
if is_colorized_depth
|
||||
else None
|
||||
)
|
||||
point_cloud = (
|
||||
np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
|
||||
if is_point_cloud
|
||||
else None
|
||||
)
|
||||
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||
self.color_timestamp = color_frame.get_timestamp()
|
||||
if self.is_depth_frame:
|
||||
self.depth_timestamp = depth_frame.get_timestamp()
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(e)
|
||||
if "Frame didn't arrive within 5000" in str(e):
|
||||
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||
self.stop_camera()
|
||||
self.start_camera()
|
||||
|
||||
return color_image, depth_image, colorized_depth, point_cloud
|
||||
|
||||
def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||
"""
|
||||
读取一帧对齐的彩色图像和深度图像
|
||||
:return: 彩色图像和深度图像的NumPy数组
|
||||
"""
|
||||
while not self.camera_on:
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
frames = self._pipeline.wait_for_frames()
|
||||
aligned_frames = self._align.process(frames)
|
||||
aligned_color_frame = aligned_frames.get_color_frame()
|
||||
self._aligned_depth_frame = aligned_frames.get_depth_frame()
|
||||
|
||||
color_image = np.asanyarray(aligned_color_frame.get_data())
|
||||
depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
|
||||
colorized_depth = (
|
||||
np.asanyarray(
|
||||
self._colorizer.colorize(self._aligned_depth_frame).get_data()
|
||||
)
|
||||
if is_colorized_depth
|
||||
else None
|
||||
)
|
||||
|
||||
if is_point_cloud:
|
||||
points = self.point_cloud.calculate(self._aligned_depth_frame)
|
||||
# 将元组数据转换为 NumPy 数组
|
||||
point_cloud = np.array(
|
||||
[[point[0], point[1], point[2]] for point in points.get_vertices()]
|
||||
)
|
||||
else:
|
||||
point_cloud = None
|
||||
|
||||
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||
self.timestamp = aligned_color_frame.get_timestamp()
|
||||
|
||||
return color_image, depth_image, colorized_depth, point_cloud
|
||||
|
||||
except Exception as e:
|
||||
if "Frame didn't arrive within 5000" in str(e):
|
||||
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||
self.stop_camera()
|
||||
self.start_camera()
|
||||
# device = self.profile.get_device()
|
||||
# device.hardware_reset()
|
||||
|
||||
def get_camera_intrinsics(self):
|
||||
"""
|
||||
获取彩色图像和深度图像的内参信息
|
||||
:return: 彩色图像和深度图像的内参信息
|
||||
"""
|
||||
# 宽高:.width, .height; 焦距:.fx, .fy; 像素坐标:.ppx, .ppy; 畸变系数:.coeffs
|
||||
logging.info("Getting camera intrinsics")
|
||||
logging.info(
|
||||
"Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
|
||||
)
|
||||
return self._color_intrinsics, self._depth_intrinsics
|
||||
|
||||
def get_3d_camera_coordinate(self, depth_pixel, align=False):
|
||||
"""
|
||||
获取深度相机坐标系下的三维坐标
|
||||
:param depth_pixel:深度像素坐标
|
||||
:param align: 是否对齐
|
||||
|
||||
:return: 深度值和相机坐标
|
||||
"""
|
||||
if not hasattr(self, "_aligned_depth_frame"):
|
||||
raise AttributeError(
|
||||
"Aligned depth frame not set. Call read_align_frame() first."
|
||||
)
|
||||
|
||||
distance = self._aligned_depth_frame.get_distance(
|
||||
depth_pixel[0], depth_pixel[1]
|
||||
)
|
||||
intrinsics = self._color_intrinsics if align else self._depth_intrinsics
|
||||
camera_coordinate = rs.rs2_deproject_pixel_to_point(
|
||||
intrinsics, depth_pixel, distance
|
||||
)
|
||||
return distance, camera_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
camera = RealSenseCamera(is_depth_frame=False)
|
||||
camera.get_serial_num()
|
||||
camera.start_camera()
|
||||
# camera.set_frame_rate(60, 60)
|
||||
color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
|
||||
camera.stop_camera()
|
||||
logging.info(f"Color image shape: {color_image.shape}")
|
||||
# logging.info(f"Depth image shape: {depth_image.shape}")
|
||||
# logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
|
||||
# logging.info(f"Point cloud shape: {point_cloud.shape}")
|
||||
logging.info(f"Color timestamp: {camera.timestamp}")
|
||||
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||
logging.info(f"Color timestamp: {camera.color_timestamp}")
|
||||
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||
logging.info("Test passed")
|
||||
@@ -0,0 +1,101 @@
|
||||
import pyrealsense2 as rs
|
||||
import numpy as np
|
||||
import h5py
|
||||
import time
|
||||
import threading
|
||||
import keyboard # 用于监听键盘输入
|
||||
|
||||
# 全局变量
|
||||
is_recording = False # 标志位,控制录制状态
|
||||
color_images = [] # 存储彩色图像
|
||||
depth_images = [] # 存储深度图像
|
||||
timestamps = [] # 存储时间戳
|
||||
|
||||
# 配置D435相机
|
||||
def configure_camera():
|
||||
pipeline = rs.pipeline()
|
||||
config = rs.config()
|
||||
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) # 彩色图像流
|
||||
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) # 深度图像流
|
||||
pipeline.start(config)
|
||||
return pipeline
|
||||
|
||||
# 监听键盘输入,控制录制状态
|
||||
def listen_for_keyboard():
|
||||
global is_recording
|
||||
while True:
|
||||
if keyboard.is_pressed('s'): # 按下 's' 开始录制
|
||||
is_recording = True
|
||||
print("Recording started.")
|
||||
time.sleep(0.5) # 防止重复触发
|
||||
elif keyboard.is_pressed('q'): # 按下 'q' 停止录制
|
||||
is_recording = False
|
||||
print("Recording stopped.")
|
||||
time.sleep(0.5) # 防止重复触发
|
||||
elif keyboard.is_pressed('e'): # 按下 'e' 退出程序
|
||||
print("Exiting program.")
|
||||
exit()
|
||||
time.sleep(0.1)
|
||||
|
||||
# 采集图像数据
|
||||
def capture_frames(pipeline):
|
||||
global is_recording, color_images, depth_images, timestamps
|
||||
try:
|
||||
while True:
|
||||
if is_recording:
|
||||
frames = pipeline.wait_for_frames()
|
||||
color_frame = frames.get_color_frame()
|
||||
depth_frame = frames.get_depth_frame()
|
||||
|
||||
if not color_frame or not depth_frame:
|
||||
continue
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = time.time()
|
||||
|
||||
# 将图像转换为numpy数组
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
depth_image = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
# 存储数据
|
||||
color_images.append(color_image)
|
||||
depth_images.append(depth_image)
|
||||
timestamps.append(timestamp)
|
||||
|
||||
print(f"Captured frame at {timestamp}")
|
||||
|
||||
else:
|
||||
time.sleep(0.1) # 如果未录制,等待一段时间
|
||||
|
||||
finally:
|
||||
pipeline.stop()
|
||||
|
||||
# 保存为HDF5文件
|
||||
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
|
||||
with h5py.File(filename, "w") as f:
|
||||
f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
|
||||
f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
|
||||
f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
|
||||
print(f"Data saved to {filename}")
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
global is_recording, color_images, depth_images, timestamps
|
||||
|
||||
# 启动键盘监听线程
|
||||
keyboard_thread = threading.Thread(target=listen_for_keyboard)
|
||||
keyboard_thread.daemon = True
|
||||
keyboard_thread.start()
|
||||
|
||||
# 配置相机
|
||||
pipeline = configure_camera()
|
||||
|
||||
# 开始采集图像
|
||||
capture_frames(pipeline)
|
||||
|
||||
# 录制结束后保存数据
|
||||
if color_images and depth_images and timestamps:
|
||||
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
realman_src/realman_aloha/shadow_camera/test/test_camera.py (new file, 152 lines)
@@ -0,0 +1,152 @@
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from os import path
|
||||
import pyrealsense2 as rs
|
||||
from shadow_camera import realsense
|
||||
import logging
|
||||
|
||||
|
||||
|
||||
def test_camera():
|
||||
camera = realsense.RealSenseCamera('241122071186')
|
||||
camera.start_camera()
|
||||
|
||||
while True:
|
||||
# result = camera.read_align_frame()
|
||||
# if result is None:
|
||||
# print('is None')
|
||||
# continue
|
||||
# start_time = time.time()
|
||||
color_image, depth_image, colorized_depth, vtx = camera.read_frame()
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
print(f"color_image: {color_image.shape}")
|
||||
# print(f"Time: {end_time - start_time}")
|
||||
cv2.imshow("bgr_image", color_image)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
camera.stop_camera()
|
||||
|
||||
|
||||
def test_get_serial_num():
|
||||
camera = realsense.RealSenseCamera()
|
||||
device = camera.get_serial_num()
|
||||
|
||||
|
||||
class CameraCapture:
|
||||
def __init__(self, camera_serial_num=None, save_dir="./save"):
|
||||
self._camera_serial_num = camera_serial_num
|
||||
self._color_save_dir = path.join(save_dir, "color")
|
||||
self._depth_save_dir = path.join(save_dir, "depth")
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
os.makedirs(self._color_save_dir, exist_ok=True)
|
||||
os.makedirs(self._depth_save_dir, exist_ok=True)
|
||||
|
||||
def get_serial_num(self):
|
||||
self._camera_serial_num = {}
|
||||
camera_names = ["left", "right", "head", "table"]
|
||||
context = rs.context()
|
||||
devices = context.query_devices() # 获取所有设备
|
||||
if len(context.devices) > 0:
|
||||
for i, device in enumerate(devices):
|
||||
self._camera_serial_num[camera_names[i]] = device.get_info(
|
||||
rs.camera_info.serial_number
|
||||
)
|
||||
print(self._camera_serial_num)
|
||||
|
||||
return self._camera_serial_num
|
||||
|
||||
def start_camera(self):
|
||||
if self._camera_serial_num is None:
|
||||
self.get_serial_num()
|
||||
self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
|
||||
self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
|
||||
self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
|
||||
|
||||
self._camera_left.start_camera()
|
||||
self._camera_right.start_camera()
|
||||
self._camera_head.start_camera()
|
||||
|
||||
def stop_camera(self):
|
||||
self._camera_left.stop_camera()
|
||||
self._camera_right.stop_camera()
|
||||
self._camera_head.stop_camera()
|
||||
|
||||
def _save_datas(self, timestamp, color_image, depth_image, camera_name):
|
||||
color_filename = path.join(
|
||||
self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
|
||||
)
|
||||
depth_filename = path.join(
|
||||
self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
|
||||
)
|
||||
cv2.imwrite(color_filename, color_image)
|
||||
cv2.imwrite(depth_filename, depth_image)
|
||||
|
||||
def capture_images(self):
|
||||
while True:
|
||||
(
|
||||
color_image_left,
|
||||
depth_image_left,
|
||||
_,
|
||||
_,
|
||||
) = self._camera_left.read_align_frame()
|
||||
(
|
||||
color_image_right,
|
||||
depth_image_right,
|
||||
_,
|
||||
_,
|
||||
) = self._camera_right.read_align_frame()
|
||||
(
|
||||
color_image_head,
|
||||
depth_image_head,
|
||||
_,
|
||||
point_cloud3,
|
||||
) = self._camera_head.read_align_frame()
|
||||
|
||||
bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
|
||||
bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
|
||||
bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
|
||||
|
||||
timestamp = time.time() * 1000
|
||||
|
||||
cv2.imshow("Camera left", bgr_color_image_left)
|
||||
cv2.imshow("Camera right", bgr_color_image_right)
|
||||
cv2.imshow("Camera head", bgr_color_image_head)
|
||||
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_left, depth_image_left, "left"
|
||||
# )
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_right, depth_image_right, "right"
|
||||
# )
|
||||
# self._save_datas(
|
||||
# timestamp, bgr_color_image_head, depth_image_head, "head"
|
||||
# )
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#test_camera()
|
||||
test_get_serial_num()
|
||||
"""
|
||||
输入相机序列号制定左右相机:
|
||||
dict:{'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
|
||||
保存路径:
|
||||
str:./save
|
||||
输入为空,自动分配相机序列号(不指定左、右、头部),保存路径为./save
|
||||
"""
|
||||
|
||||
# capture = CameraCapture()
|
||||
# capture.get_serial_num()
|
||||
# test_get_serial_num()
|
||||
|
||||
# capture.start_camera()
|
||||
# capture.capture_images()
|
||||
# capture.stop_camera()
|
||||
@@ -0,0 +1,71 @@
|
||||
import pytest
|
||||
import pyrealsense2 as rs
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
|
||||
|
||||
class TestRealSenseCamera:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_camera(self):
|
||||
self.camera = RealSenseCamera()
|
||||
|
||||
def test_get_serial_num(self):
|
||||
serial_nums = self.camera.get_serial_num()
|
||||
assert isinstance(serial_nums, dict)
|
||||
assert len(serial_nums) > 0
|
||||
|
||||
def test_start_stop_camera(self):
|
||||
self.camera.start_camera()
|
||||
assert self.camera.camera_on is True
|
||||
self.camera.stop_camera()
|
||||
assert self.camera.camera_on is False
|
||||
|
||||
def test_set_resolution(self):
|
||||
color_resolution = [1280, 720]
|
||||
depth_resolution = [1280, 720]
|
||||
self.camera.set_resolution(color_resolution, depth_resolution)
|
||||
assert self.camera._color_resolution == color_resolution
|
||||
assert self.camera._depth_resolution == depth_resolution
|
||||
|
||||
def test_set_frame_rate(self):
|
||||
color_fps = 60
|
||||
depth_fps = 60
|
||||
self.camera.set_frame_rate(color_fps, depth_fps)
|
||||
assert self.camera._color_frames_rate == color_fps
|
||||
assert self.camera._depth_frames_rate == depth_fps
|
||||
|
||||
def test_read_frame(self):
|
||||
self.camera.start_camera()
|
||||
color_image, depth_image, colorized_depth, point_cloud = (
|
||||
self.camera.read_frame()
|
||||
)
|
||||
assert color_image is not None
|
||||
assert depth_image is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_read_align_frame(self):
|
||||
self.camera.start_camera()
|
||||
color_image, depth_image, colorized_depth, point_cloud = (
|
||||
self.camera.read_align_frame()
|
||||
)
|
||||
assert color_image is not None
|
||||
assert depth_image is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_get_camera_intrinsics(self):
|
||||
self.camera.start_camera()
|
||||
color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
|
||||
assert color_intrinsics is not None
|
||||
assert depth_intrinsics is not None
|
||||
self.camera.stop_camera()
|
||||
|
||||
def test_get_3d_camera_coordinate(self):
|
||||
self.camera.start_camera()
|
||||
# 先调用 read_align_frame 方法以确保 _aligned_depth_frame 被设置
|
||||
self.camera.read_align_frame()
|
||||
depth_pixel = [320, 240]
|
||||
distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
|
||||
depth_pixel, align=True
|
||||
)
|
||||
assert distance > 0
|
||||
assert len(camera_coordinate) == 3
|
||||
self.camera.stop_camera()
|
||||
realman_src/realman_aloha/shadow_rm_act/.gitignore (new vendored file, 10 lines)
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
build/
|
||||
devel/
|
||||
dist/
|
||||
data/
|
||||
.catkin_workspace
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
.vscode/
|
||||
realman_src/realman_aloha/shadow_rm_act/README.md (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
# ACT: Action Chunking with Transformers
|
||||
|
||||
### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
|
||||
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve way after loss plateaus.
|
||||
|
||||
#### Project Website: https://tonyzhaozh.github.io/aloha/
|
||||
|
||||
This repo contains the implementation of ACT, together with 2 simulated environments:
|
||||
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
|
||||
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
|
||||
|
||||
### Updates:
|
||||
You can find all scripted/human demo for simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
|
||||
|
||||
|
||||
### Repo Structure
|
||||
- ``imitate_episodes.py`` Train and Evaluate ACT
|
||||
- ``policy.py`` An adaptor for ACT policy
|
||||
- ``detr`` Model definitions of ACT, modified from DETR
|
||||
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
|
||||
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
|
||||
- ``scripted_policy.py`` Scripted policies for sim environments
|
||||
- ``constants.py`` Constants shared across files
|
||||
- ``utils.py`` Utils such as data loading and helper functions
|
||||
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
|
||||
|
||||
|
||||
### Installation
|
||||
|
||||
conda create -n aloha python=3.8.10
|
||||
conda activate aloha
|
||||
pip install torchvision
|
||||
pip install torch
|
||||
pip install pyquaternion
|
||||
pip install pyyaml
|
||||
pip install rospkg
|
||||
pip install pexpect
|
||||
pip install mujoco==2.3.7
|
||||
pip install dm_control==1.0.14
|
||||
pip install opencv-python
|
||||
pip install matplotlib
|
||||
pip install einops
|
||||
pip install packaging
|
||||
pip install h5py
|
||||
pip install ipython
|
||||
cd act/detr && pip install -e .
|
||||
|
||||
### Example Usages
|
||||
|
||||
To set up a new terminal, run:
|
||||
|
||||
conda activate aloha
|
||||
cd <path to act repo>
|
||||
|
||||
### Simulated experiments
|
||||
|
||||
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
|
||||
To generate 50 episodes of scripted data, run:
|
||||
|
||||
python3 record_sim_episodes.py \
|
||||
--task_name sim_transfer_cube_scripted \
|
||||
--dataset_dir <data save dir> \
|
||||
--num_episodes 50
|
||||
|
||||
You can add the flag ``--onscreen_render`` to see real-time rendering.
|
||||
To visualize the episode after it is collected, run
|
||||
|
||||
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
|
||||
|
||||
To train ACT:
|
||||
|
||||
# Transfer Cube task
|
||||
python3 imitate_episodes.py \
|
||||
--task_name sim_transfer_cube_scripted \
|
||||
--ckpt_dir <ckpt dir> \
|
||||
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
|
||||
--num_epochs 2000 --lr 1e-5 \
|
||||
--seed 0
|
||||
|
||||
|
||||
To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
|
||||
The success rate should be around 90% for transfer cube, and around 50% for insertion.
|
||||
To enable temporal ensembling, add flag ``--temporal_agg``.
|
||||
Videos will be saved to ``<ckpt_dir>`` for each rollout.
|
||||
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
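
For example, an evaluation run for the Transfer Cube checkpoint trained above could look like this (a sketch only: the flags are the same as for training, with ``--eval`` and the optional ``--temporal_agg`` / ``--onscreen_render`` appended; substitute your own ``<ckpt dir>``):

    python3 imitate_episodes.py \
    --task_name sim_transfer_cube_scripted \
    --ckpt_dir <ckpt dir> \
    --policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
    --num_epochs 2000 --lr 1e-5 \
    --seed 0 \
    --eval --temporal_agg --onscreen_render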
|
||||
|
||||
For real-world data where things can be harder to model, train for at least 5000 epochs or 3-4 times the length after the loss has plateaued.
|
||||
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.
|
||||
|
||||
realman_src/realman_aloha/shadow_rm_act/config/config.yaml (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
robot_env: {
|
||||
# TODO change the path to the correct one
|
||||
rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||
rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||
arm_axis: 6,
|
||||
head_camera: '215222076892',
|
||||
bottom_camera: '215222076981',
|
||||
left_camera: '152122078151',
|
||||
right_camera: '152122073489',
|
||||
# init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
|
||||
# init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
|
||||
init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||
init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||
}
|
||||
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
|
||||
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
|
||||
# checkpoint_name: 'policy_best.ckpt'
|
||||
checkpoint_name: 'policy_9500.ckpt'
|
||||
state_dim: 14
|
||||
save_episode: True
|
||||
num_rollouts: 50 # number of rollouts (trajectories) to collect during training
real_robot: True
policy_class: 'ACT'
onscreen_render: False
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
episode_len: 300 # maximum episode length (number of timesteps)
task_name: 'aloha_01_11.28'
temporal_agg: False # whether to use temporal aggregation
batch_size: 8 # number of samples per batch during training
seed: 1000 # random seed
chunk_size: 30 # chunk size used when processing sequences
eval_every: 1 # evaluate the model every eval_every steps
num_steps: 10000 # total number of training steps
validate_every: 1 # validate the model every validate_every steps
save_every: 500 # save a checkpoint every save_every steps
load_pretrain: False # whether to load a pretrained model
resume_ckpt_path:
name_filter: # TODO
skip_mirrored_data: False # whether to skip mirrored data (e.g. symmetry-based data augmentation)
stats_dir:
sample_weights:
train_ratio: 0.8 # fraction of the data used for training (the rest is used for validation)

policy_config: {
|
||||
hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
|
||||
state_dim: 14, # Dimension of the state
|
||||
position_embedding: 'sine', # ('sine', 'learned').Type of positional embedding to use on top of the image features
|
||||
lr_backbone: 1.0e-5,
|
||||
masks: False, # If true, the model masks the non-visible pixels
|
||||
backbone: 'resnet18',
|
||||
dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
|
||||
dropout: 0.1, # Dropout applied in the transformer
|
||||
nheads: 8,
|
||||
dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
|
||||
enc_layers: 4, # Number of encoding layers in the transformer
|
||||
dec_layers: 7, # Number of decoding layers in the transformer
|
||||
pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
|
||||
num_queries: 30,
|
||||
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
|
||||
vq: False,
|
||||
vq_class: none,
|
||||
vq_dim: 64,
|
||||
action_dim: 14,
|
||||
no_encoder: False,
|
||||
lr: 1.0e-5,
|
||||
weight_decay: 1.0e-4,
|
||||
kl_weight: 10,
|
||||
|
||||
# lr_drop: 200,
|
||||
# clip_max_norm: 0.1,
|
||||
}
|
||||
|
||||
|
||||
|
||||
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py (new file, 267 lines)
@@ -0,0 +1,267 @@
|
||||
import numpy as np
|
||||
import collections
|
||||
import os
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_CLOSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def make_ee_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_pose (7), # position and quaternion for end effector
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = TransferCubeEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif 'sim_insertion' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = InsertionEETask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXEETask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
a_len = len(action) // 2
|
||||
action_left = action[:a_len]
|
||||
action_right = action[a_len:]
|
||||
|
||||
# set mocap position and quat
|
||||
# left
|
||||
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||
|
||||
# set gripper
|
||||
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
|
||||
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
|
||||
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||
|
||||
def initialize_robots(self, physics):
|
||||
# reset joint position
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
|
||||
# reset mocap to align with end effector
|
||||
# to obtain these numbers:
|
||||
# (1) make an ee_sim env and reset to the same start_pose
|
||||
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||
# repeat the same for right side
|
||||
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||
# right
|
||||
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||
|
||||
# reset gripper control
|
||||
close_gripper_control = np.array([
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||
])
|
||||
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
# note: it is important to do .copy()
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||
# used in scripted policy to obtain starting pose
|
||||
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
|
||||
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
|
||||
|
||||
# used when replaying joint trajectory
|
||||
obs['gripper_ctrl'] = physics.data.ctrl.copy()
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize box position
|
||||
cube_pose = sample_box_pose()
|
||||
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
|
||||
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionEETask(BimanualViperXEETask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
self.initialize_robots(physics)
|
||||
# randomize peg and socket position
|
||||
peg_pose, socket_pose = sample_insertion_pose()
|
||||
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||
|
||||
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
|
||||
peg_start_idx = id2index(peg_start_id)
|
||||
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
|
||||
socket_start_idx = id2index(socket_start_id)
|
||||
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||
# print(f"randomized cube position to {cube_position}")
|
||||
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
realman_src/realman_aloha/shadow_rm_act/pyproject.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_act"
|
||||
version = "0.1.0"
|
||||
description = "Embodied data, ACT and other methods; training and verification function packages"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.9"
|
||||
wandb = ">=0.18.0"
|
||||
einops = ">=0.8.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # Dependencies needed during development, e.g. testing and documentation tools.
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # Command-line scripts that let users run specified functions from the command line.
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
import h5py
|
||||
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
from sim_env import make_sim_env, BOX_POSE
|
||||
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Generate demonstration data in simulation.
|
||||
First roll out the policy (defined in ee space) in ee_sim_env and obtain the joint trajectory.
Replace the gripper joint positions with the commanded joint positions.
Replay this joint trajectory (as an action sequence) in sim_env, and record all observations.
Save this episode of data, and continue to the next episode of data collection.
|
||||
"""
|
||||
|
||||
task_name = args['task_name']
|
||||
dataset_dir = args['dataset_dir']
|
||||
num_episodes = args['num_episodes']
|
||||
onscreen_render = args['onscreen_render']
|
||||
inject_noise = False
|
||||
render_cam_name = 'angle'
|
||||
|
||||
if not os.path.isdir(dataset_dir):
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||
if task_name == 'sim_transfer_cube_scripted':
|
||||
policy_cls = PickAndTransferPolicy
|
||||
elif task_name == 'sim_insertion_scripted':
|
||||
policy_cls = InsertionPolicy
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
success = []
|
||||
for episode_idx in range(num_episodes):
|
||||
print(f'{episode_idx=}')
|
||||
print('Rolling out EE space scripted policy')
|
||||
# setup the environment
|
||||
env = make_ee_sim_env(task_name)
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
policy = policy_cls(inject_noise)
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.002)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
joint_traj = [ts.observation['qpos'] for ts in episode]
|
||||
# replace gripper pose with gripper control
|
||||
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
|
||||
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
|
||||
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
|
||||
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
|
||||
joint[6] = left_ctrl
|
||||
joint[6+7] = right_ctrl
|
||||
|
||||
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
|
||||
|
||||
# clear unused variables
|
||||
del env
|
||||
del episode
|
||||
del policy
|
||||
|
||||
# setup the environment
|
||||
print('Replaying joint commands')
|
||||
env = make_sim_env(task_name)
|
||||
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
|
||||
ts = env.reset()
|
||||
|
||||
episode_replay = [ts]
|
||||
# setup plotting
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||
plt.ion()
|
||||
for t in range(len(joint_traj)): # note: this will increase episode length by 1
|
||||
action = joint_traj[t]
|
||||
ts = env.step(action)
|
||||
episode_replay.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||
plt.pause(0.02)
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
|
||||
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
|
||||
if episode_max_reward == env.task.max_reward:
|
||||
success.append(1)
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
success.append(0)
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
plt.close()
|
||||
|
||||
"""
|
||||
For each timestep:
|
||||
observations
|
||||
- images
|
||||
- each_cam_name (480, 640, 3) 'uint8'
|
||||
- qpos (14,) 'float64'
|
||||
- qvel (14,) 'float64'
|
||||
|
||||
action (14,) 'float64'
|
||||
"""
|
||||
|
||||
data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/observations/qvel': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# because of the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
|
||||
# truncate here to be consistent
|
||||
joint_traj = joint_traj[:-1]
|
||||
episode_replay = episode_replay[:-1]
|
||||
|
||||
# len(joint_traj) i.e. actions: max_timesteps
|
||||
# len(episode_replay) i.e. time steps: max_timesteps + 1
|
||||
max_timesteps = len(joint_traj)
|
||||
while joint_traj:
|
||||
action = joint_traj.pop(0)
|
||||
ts = episode_replay.pop(0)
|
||||
data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
data_dict['/action'].append(action)
|
||||
for cam_name in camera_names:
|
||||
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
# HDF5
|
||||
t0 = time.time()
|
||||
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
|
||||
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
|
||||
root.attrs['sim'] = True
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
for cam_name in camera_names:
|
||||
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3), )
|
||||
# compression='gzip',compression_opts=2,)
|
||||
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
|
||||
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
|
||||
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
|
||||
action = root.create_dataset('action', (max_timesteps, 14))
|
||||
|
||||
for name, array in data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f'Saving: {time.time() - t0:.1f} secs\n')
|
||||
|
||||
print(f'Saved to {dataset_dir}')
|
||||
print(f'Success: {np.sum(success)} / {len(success)}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
|
||||
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
|
||||
parser.add_argument('--onscreen_render', action='store_true')
|
||||
|
||||
main(vars(parser.parse_args()))
|
||||
|
||||
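For reference, main() above can also be driven programmatically; a minimal sketch, assuming 'sim_transfer_cube_scripted' is defined in SIM_TASK_CONFIGS and that 'top' is one of its camera_names (both assumptions taken from the constants.py conventions):

import h5py
from record_sim_episodes import main

main({
    'task_name': 'sim_transfer_cube_scripted',
    'dataset_dir': 'data/sim_transfer_cube_scripted',   # hypothetical output directory
    'num_episodes': 2,
    'onscreen_render': False,
})

# Read one saved episode back, following the HDF5 layout documented in main():
with h5py.File('data/sim_transfer_cube_scripted/episode_0.hdf5', 'r') as root:
    qpos = root['/observations/qpos'][()]       # (episode_len, 14)
    action = root['/action'][()]                # (episode_len, 14)
    top = root['/observations/images/top'][()]  # (episode_len, 480, 640, 3), uint8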
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pyquaternion import Quaternion
|
||||
|
||||
from constants import SIM_TASK_CONFIGS
|
||||
from ee_sim_env import make_ee_sim_env
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
|
||||
class BasePolicy:
|
||||
def __init__(self, inject_noise=False):
|
||||
self.inject_noise = inject_noise
|
||||
self.step_count = 0
|
||||
self.left_trajectory = None
|
||||
self.right_trajectory = None
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def interpolate(curr_waypoint, next_waypoint, t):
|
||||
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
|
||||
curr_xyz = curr_waypoint['xyz']
|
||||
curr_quat = curr_waypoint['quat']
|
||||
curr_grip = curr_waypoint['gripper']
|
||||
next_xyz = next_waypoint['xyz']
|
||||
next_quat = next_waypoint['quat']
|
||||
next_grip = next_waypoint['gripper']
|
||||
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
|
||||
quat = curr_quat + (next_quat - curr_quat) * t_frac
|
||||
gripper = curr_grip + (next_grip - curr_grip) * t_frac
|
||||
return xyz, quat, gripper
|
||||
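# Note: xyz, quat and gripper are interpolated componentwise (plain lerp), which
# is adequate for the small waypoint-to-waypoint rotations used here. A slerp
# variant (sketch, using the pyquaternion import already present in this file)
# would be:
#   quat = Quaternion.slerp(Quaternion(curr_quat), Quaternion(next_quat), t_frac).elements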
|
||||
def __call__(self, ts):
|
||||
# generate trajectory at first timestep, then open-loop execution
|
||||
if self.step_count == 0:
|
||||
self.generate_trajectory(ts)
|
||||
|
||||
# obtain left and right waypoints
|
||||
if self.left_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_left_waypoint = self.left_trajectory.pop(0)
|
||||
next_left_waypoint = self.left_trajectory[0]
|
||||
|
||||
if self.right_trajectory[0]['t'] == self.step_count:
|
||||
self.curr_right_waypoint = self.right_trajectory.pop(0)
|
||||
next_right_waypoint = self.right_trajectory[0]
|
||||
|
||||
# interpolate between waypoints to obtain current pose and gripper command
|
||||
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
|
||||
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
|
||||
|
||||
# Inject noise
|
||||
if self.inject_noise:
|
||||
scale = 0.01
|
||||
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
|
||||
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
|
||||
|
||||
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
|
||||
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
|
||||
|
||||
self.step_count += 1
|
||||
return np.concatenate([action_left, action_right])
|
||||
|
||||
|
||||
class PickAndTransferPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
box_info = np.array(ts_first.observation['env_state'])
|
||||
box_xyz = box_info[:3]
|
||||
box_quat = box_info[3:]
|
||||
# print(f"Generate trajectory for {box_xyz=}")
|
||||
|
||||
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.25])
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
|
||||
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
|
||||
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
|
||||
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
|
||||
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
|
||||
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
|
||||
]
|
||||
|
||||
|
||||
class InsertionPolicy(BasePolicy):
|
||||
|
||||
def generate_trajectory(self, ts_first):
|
||||
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||
|
||||
peg_info = np.array(ts_first.observation['env_state'])[:7]
|
||||
peg_xyz = peg_info[:3]
|
||||
peg_quat = peg_info[3:]
|
||||
|
||||
socket_info = np.array(ts_first.observation['env_state'])[7:]
|
||||
socket_xyz = socket_info[:3]
|
||||
socket_quat = socket_info[3:]
|
||||
|
||||
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||
|
||||
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
|
||||
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
|
||||
|
||||
meet_xyz = np.array([0, 0.5, 0.15])
|
||||
lift_right = 0.00715
|
||||
|
||||
self.left_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
|
||||
]
|
||||
|
||||
self.right_trajectory = [
|
||||
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
|
||||
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
|
||||
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
|
||||
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
|
||||
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||
|
||||
]
|
||||
|
||||
|
||||
def test_policy(task_name):
|
||||
# example rolling out pick_and_transfer policy
|
||||
onscreen_render = True
|
||||
inject_noise = False
|
||||
|
||||
# setup the environment
|
||||
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
env = make_ee_sim_env('sim_transfer_cube')
|
||||
elif 'sim_insertion' in task_name:
|
||||
env = make_ee_sim_env('sim_insertion')
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
for episode_idx in range(2):
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
if onscreen_render:
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||
plt.ion()
|
||||
|
||||
policy = PickAndTransferPolicy(inject_noise)
|
||||
for step in range(episode_len):
|
||||
action = policy(ts)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
if onscreen_render:
|
||||
plt_img.set_data(ts.observation['images']['angle'])
|
||||
plt.pause(0.02)
|
||||
plt.close()
|
||||
|
||||
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||
if episode_return > 0:
|
||||
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||
else:
|
||||
print(f"{episode_idx=} Failed")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_task_name = 'sim_transfer_cube_scripted'
|
||||
test_policy(test_task_name)
|
||||
|
||||
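Each action produced by these scripted policies is a 16-dim end-effector command (per arm: xyz, wxyz quaternion, gripper). A small sketch of that layout with illustrative values:

import numpy as np

left = np.concatenate([[0.0, 0.4, 0.3], [1, 0, 0, 0], [1.0]])    # xyz (3) + quat wxyz (4) + gripper (1)
right = np.concatenate([[0.1, 0.4, 0.3], [1, 0, 0, 0], [0.0]])
action = np.concatenate([left, right])   # shape (16,), consumed by the env created via make_ee_sim_env
print(action.shape)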
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
@@ -0,0 +1,278 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import collections
|
||||
import matplotlib.pyplot as plt
|
||||
from dm_control import mujoco
|
||||
from dm_control.rl import control
|
||||
from dm_control.suite import base
|
||||
|
||||
from constants import DT, XML_DIR, START_ARM_POSE
|
||||
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
BOX_POSE = [None] # to be changed from outside
|
||||
|
||||
def make_sim_env(task_name):
|
||||
"""
|
||||
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
if 'sim_transfer_cube' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = TransferCubeTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
elif 'sim_insertion' in task_name:
|
||||
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
|
||||
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||
task = InsertionTask(random=False)
|
||||
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||
n_sub_steps=None, flat_observation=False)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return env
|
||||
|
||||
class BimanualViperXTask(base.Task):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
|
||||
def before_step(self, action, physics):
|
||||
left_arm_action = action[:6]
|
||||
right_arm_action = action[7:7+6]
|
||||
normalized_left_gripper_action = action[6]
|
||||
normalized_right_gripper_action = action[7+6]
|
||||
|
||||
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
|
||||
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
|
||||
|
||||
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||
|
||||
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
|
||||
super().before_step(env_action, physics)
|
||||
return
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_qpos(physics):
|
||||
qpos_raw = physics.data.qpos.copy()
|
||||
left_qpos_raw = qpos_raw[:8]
|
||||
right_qpos_raw = qpos_raw[8:16]
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
@staticmethod
|
||||
def get_qvel(physics):
|
||||
qvel_raw = physics.data.qvel.copy()
|
||||
left_qvel_raw = qvel_raw[:8]
|
||||
right_qvel_raw = qvel_raw[8:16]
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_observation(self, physics):
|
||||
obs = collections.OrderedDict()
|
||||
obs['qpos'] = self.get_qpos(physics)
|
||||
obs['qvel'] = self.get_qvel(physics)
|
||||
obs['env_state'] = self.get_env_state(physics)
|
||||
obs['images'] = dict()
|
||||
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||
|
||||
return obs
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TransferCubeTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether left gripper is holding the box
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_table = ("red_box", "table") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_right_gripper:
|
||||
reward = 1
|
||||
if touch_right_gripper and not touch_table: # lifted
|
||||
reward = 2
|
||||
if touch_left_gripper: # attempted transfer
|
||||
reward = 3
|
||||
if touch_left_gripper and not touch_table: # successful transfer
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
class InsertionTask(BimanualViperXTask):
|
||||
def __init__(self, random=None):
|
||||
super().__init__(random=random)
|
||||
self.max_reward = 4
|
||||
|
||||
def initialize_episode(self, physics):
|
||||
"""Sets the state of the environment at the start of each episode."""
|
||||
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||
# reset qpos, control and box position
|
||||
with physics.reset_context():
|
||||
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||
assert BOX_POSE[0] is not None
|
||||
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
|
||||
# print(f"{BOX_POSE=}")
|
||||
super().initialize_episode(physics)
|
||||
|
||||
@staticmethod
|
||||
def get_env_state(physics):
|
||||
env_state = physics.data.qpos.copy()[16:]
|
||||
return env_state
|
||||
|
||||
def get_reward(self, physics):
|
||||
# return whether peg touches the pin
|
||||
all_contact_pairs = []
|
||||
for i_contact in range(physics.data.ncon):
|
||||
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||
contact_pair = (name_geom_1, name_geom_2)
|
||||
all_contact_pairs.append(contact_pair)
|
||||
|
||||
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||
|
||||
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||
("socket-2", "table") in all_contact_pairs or \
|
||||
("socket-3", "table") in all_contact_pairs or \
|
||||
("socket-4", "table") in all_contact_pairs
|
||||
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||
("red_peg", "socket-2") in all_contact_pairs or \
|
||||
("red_peg", "socket-3") in all_contact_pairs or \
|
||||
("red_peg", "socket-4") in all_contact_pairs
|
||||
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||
|
||||
reward = 0
|
||||
if touch_left_gripper and touch_right_gripper: # touch both
|
||||
reward = 1
|
||||
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||
reward = 2
|
||||
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||
reward = 3
|
||||
if pin_touched: # successful insertion
|
||||
reward = 4
|
||||
return reward
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14)
|
||||
# arm action
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# gripper action
|
||||
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
|
||||
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
|
||||
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
|
||||
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
|
||||
action[6] = normalized_left_pos
|
||||
action[7+6] = normalized_right_pos
|
||||
return action
|
||||
|
||||
def test_sim_teleop():
|
||||
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
|
||||
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
|
||||
|
||||
# source of data
|
||||
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_left', init_node=True)
|
||||
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||
robot_name=f'master_right', init_node=False)
|
||||
|
||||
# setup the environment
|
||||
env = make_sim_env('sim_transfer_cube')
|
||||
ts = env.reset()
|
||||
episode = [ts]
|
||||
# setup plotting
|
||||
ax = plt.subplot()
|
||||
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||
plt.ion()
|
||||
|
||||
for t in range(1000):
|
||||
action = get_action(master_bot_left, master_bot_right)
|
||||
ts = env.step(action)
|
||||
episode.append(ts)
|
||||
|
||||
plt_img.set_data(ts.observation['images']['angle'])
|
||||
plt.pause(0.02)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sim_teleop()
|
||||
|
||||
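A minimal sketch of driving make_sim_env directly (BOX_POSE must be set before reset, exactly as record_sim_episodes.py does; the pose values below are illustrative):

import numpy as np
from sim_env import make_sim_env, BOX_POSE
from constants import START_ARM_POSE

BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]   # x, y, z, quaternion wxyz
env = make_sim_env('sim_transfer_cube')
ts = env.reset()

# 14-dim action: left arm (6) + left gripper (1) + right arm (6) + right gripper (1)
action = np.zeros(14)
action[:6] = START_ARM_POSE[:6]        # hold the left arm at its start pose
action[7:13] = START_ARM_POSE[8:14]    # hold the right arm at its start pose
ts = env.step(action)
print(ts.observation['qpos'].shape, ts.reward)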
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,575 @@
|
||||
import os
|
||||
import time
|
||||
import yaml
|
||||
import torch
|
||||
import pickle
|
||||
import dm_env
|
||||
import logging
|
||||
import collections
|
||||
import numpy as np
|
||||
import tracemalloc
|
||||
from einops import rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.utils.utils import set_seed
|
||||
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
# # Hide the h5py warning "Creating converter from 7 to 5"
|
||||
# logging.getLogger("h5py").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class RmActEvaluator:
|
||||
def __init__(self, config, save_episode=True, num_rollouts=50):
|
||||
"""
|
||||
初始化Evaluator类
|
||||
|
||||
Args:
|
||||
config (dict): 配置字典
|
||||
checkpoint_name (str): 检查点名称
|
||||
save_episode (bool): 是否保存每个episode
|
||||
num_rollouts (int): 滚动次数
|
||||
"""
|
||||
self.config = config
|
||||
self._seed = config["seed"]
|
||||
self.robot_env = config["robot_env"]
|
||||
self.checkpoint_dir = config["checkpoint_dir"]
|
||||
self.checkpoint_name = config["checkpoint_name"]
|
||||
self.save_episode = save_episode
|
||||
self.num_rollouts = num_rollouts
|
||||
self.state_dim = config["state_dim"]
|
||||
self.real_robot = config["real_robot"]
|
||||
self.policy_class = config["policy_class"]
|
||||
self.onscreen_render = config["onscreen_render"]
|
||||
self.camera_names = config["camera_names"]
|
||||
self.max_timesteps = config["episode_len"]
|
||||
self.task_name = config["task_name"]
|
||||
self.temporal_agg = config["temporal_agg"]
|
||||
self.onscreen_cam = "angle"
|
||||
self.policy_config = config["policy_config"]
|
||||
self.vq = config["policy_config"]["vq"]
|
||||
# self.actuator_config = config["actuator_config"]
|
||||
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
|
||||
self.stats = None
|
||||
self.env = None
|
||||
self.env_max_reward = 0
|
||||
|
||||
def _make_policy(self, policy_class, policy_config):
|
||||
"""
|
||||
根据策略类和配置创建策略对象
|
||||
|
||||
Args:
|
||||
policy_class (str): 策略类名称
|
||||
policy_config (dict): 策略配置字典
|
||||
|
||||
Returns:
|
||||
policy: 创建的策略对象
|
||||
"""
|
||||
if policy_class == "ACT":
|
||||
return ACTPolicy(policy_config)
|
||||
elif policy_class == "CNNMLP":
|
||||
return CNNMLPPolicy(policy_config)
|
||||
elif policy_class == "Diffusion":
|
||||
return DiffusionPolicy(policy_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
|
||||
|
||||
def load_policy_and_stats(self):
|
||||
"""
|
||||
加载策略和统计数据
|
||||
"""
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
|
||||
logging.info(f"Loading policy from: {checkpoint_path}")
|
||||
self.policy = self._make_policy(self.policy_class, self.policy_config)
|
||||
# Load the model weights and switch to evaluation mode
|
||||
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
|
||||
self.policy.cuda()
|
||||
self.policy.eval()
|
||||
|
||||
if self.vq:
|
||||
vq_dim = self.config["policy_config"]["vq_dim"]
|
||||
vq_class = self.config["policy_config"]["vq_class"]
|
||||
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
|
||||
latent_model_checkpoint_path = os.path.join(
|
||||
self.checkpoint_dir, "latent_model_last.ckpt"
|
||||
)
|
||||
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
|
||||
self.latent_model.eval()
|
||||
self.latent_model.cuda()
|
||||
logging.info(
|
||||
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
|
||||
)
|
||||
else:
|
||||
logging.info(f"Loaded: {checkpoint_path}")
|
||||
|
||||
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
|
||||
with open(stats_path, "rb") as f:
|
||||
self.stats = pickle.load(f)
|
||||
|
||||
def pre_process(self, state_qpos):
|
||||
"""
|
||||
预处理状态位置
|
||||
|
||||
Args:
|
||||
state_qpos (np.array): 状态位置数组
|
||||
|
||||
Returns:
|
||||
np.array: 预处理后的状态位置
|
||||
"""
|
||||
if self.policy_class == "Diffusion":
|
||||
return ((state_qpos + 1) / 2) * (
|
||||
self.stats["action_max"] - self.stats["action_min"]
|
||||
) + self.stats["action_min"]
|
||||
# Standardize to zero mean and unit standard deviation
|
||||
|
||||
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
|
||||
|
||||
def post_process(self, action):
|
||||
"""
|
||||
后处理动作
|
||||
|
||||
Args:
|
||||
action (np.array): 动作数组
|
||||
|
||||
Returns:
|
||||
np.array: 后处理后的动作
|
||||
"""
|
||||
# Undo the standardization (de-normalize)
|
||||
return action * self.stats["action_std"] + self.stats["action_mean"]
|
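# In short: qpos fed to the policy is z-scored with the dataset statistics
# (qpos_mean / qpos_std from dataset_stats.pkl), and the network output is
# mapped back with action_mean / action_std. The Diffusion branch of
# pre_process instead rescales its input from [-1, 1] into
# [action_min, action_max].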
||||
|
||||
def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
|
||||
"""
|
||||
获取图像
|
||||
|
||||
Args:
|
||||
timestep (object): 时间步对象
|
||||
camera_names (list): 相机名称列表
|
||||
random_crop_resize (bool): 是否随机裁剪和调整大小
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 处理后的图像,归一化(num_cameras, channels, height, width)
|
||||
"""
|
||||
current_images = []
|
||||
for cam_name in camera_names:
|
||||
current_image = rearrange(
|
||||
timestep.observation["images"][cam_name], "h w c -> c h w"
|
||||
)
|
||||
current_images.append(current_image)
|
||||
current_image = np.stack(current_images, axis=0)
|
||||
current_image = (
|
||||
torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
|
||||
)
|
||||
|
||||
if random_crop_resize:
|
||||
logging.info("Random crop resize is used!")
|
||||
original_size = current_image.shape[-2:]
|
||||
ratio = 0.95
|
||||
current_image = current_image[
|
||||
...,
|
||||
int(original_size[0] * (1 - ratio) / 2) : int(
|
||||
original_size[0] * (1 + ratio) / 2
|
||||
),
|
||||
int(original_size[1] * (1 - ratio) / 2) : int(
|
||||
original_size[1] * (1 + ratio) / 2
|
||||
),
|
||||
]
|
||||
current_image = current_image.squeeze(0)
|
||||
resize_transform = transforms.Resize(original_size, antialias=True)
|
||||
current_image = resize_transform(current_image)
|
||||
current_image = current_image.unsqueeze(0)
|
||||
|
||||
return current_image
|
||||
|
||||
def load_environment(self):
|
||||
"""
|
||||
加载环境
|
||||
"""
|
||||
if self.real_robot:
|
||||
self.env = DeviceAloha(self.robot_env)
|
||||
self.env_max_reward = 0
|
||||
else:
|
||||
from sim_env import make_sim_env
|
||||
|
||||
self.env = make_sim_env(self.task_name)
|
||||
self.env_max_reward = self.env.task.max_reward
|
||||
|
||||
def get_auto_index(self, checkpoint_dir):
|
||||
max_idx = 1000
|
||||
for i in range(max_idx + 1):
|
||||
if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
|
||||
return i
|
||||
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
|
||||
|
||||
def evaluate(self, checkpoint_name=None):
|
||||
"""
|
||||
评估策略
|
||||
|
||||
Returns:
|
||||
tuple: 成功率和平均回报
|
||||
"""
|
||||
if checkpoint_name is not None:
|
||||
self.checkpoint_name = checkpoint_name
|
||||
set_seed(self._seed)  # random seed for numpy and torch
|
||||
self.load_policy_and_stats()
|
||||
self.load_environment()
|
||||
|
||||
query_frequency = self.policy_config["num_queries"]
|
||||
|
||||
# With temporal aggregation, query the policy at every timestep (query_frequency = 1)
|
||||
if self.temporal_agg:
|
||||
query_frequency = 1
|
||||
num_queries = self.policy_config["num_queries"]
|
||||
|
||||
# # On the real robot, the base delay is 13 (???)
|
||||
# if self.real_robot:
|
||||
# BASE_DELAY = 13
|
||||
# # query_frequency -= BASE_DELAY
|
||||
|
||||
max_timesteps = int(self.max_timesteps * 1) # may increase for real-world tasks
|
||||
episode_returns = []
|
||||
highest_rewards = []
|
||||
|
||||
for rollout_id in range(self.num_rollouts):
|
||||
|
||||
timestep = self.env.reset()
|
||||
|
||||
if self.onscreen_render:
|
||||
# TODO plot the rendering
|
||||
pass
|
||||
if self.temporal_agg:
|
||||
all_time_actions = torch.zeros(
|
||||
[max_timesteps, max_timesteps + num_queries, self.state_dim]
|
||||
).cuda()
|
||||
qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
|
||||
rewards = []
|
||||
|
||||
with torch.inference_mode():
|
||||
time_0 = time.time()
|
||||
DT = 1 / 30
|
||||
cumulated_delay = 0
|
||||
for t in range(max_timesteps):
|
||||
time_1 = time.time()
|
||||
if self.onscreen_render:
|
||||
# TODO display the rendered image
|
||||
pass
|
||||
# process previous timestep to get qpos and image_list
|
||||
obs = timestep.observation
|
||||
qpos_numpy = np.array(obs["qpos"])
|
||||
qpos_history_raw[t] = qpos_numpy
|
||||
qpos = self.pre_process(qpos_numpy)
|
||||
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||
|
||||
logging.info(f"t{t}")
|
||||
|
||||
if t % query_frequency == 0:
|
||||
current_image = self.get_image_torch(
|
||||
timestep,
|
||||
self.camera_names,
|
||||
random_crop_resize=(
|
||||
self.config["policy_class"] == "Diffusion"
|
||||
),
|
||||
)
|
||||
|
||||
if t == 0:
|
||||
# Network warm-up
|
||||
for _ in range(10):
|
||||
self.policy(qpos, current_image)
|
||||
logging.info("Network warm up done")
|
||||
|
||||
if self.config["policy_class"] == "ACT":
|
||||
if t % query_frequency == 0:
|
||||
if self.vq:
|
||||
if rollout_id == 0:
|
||||
for _ in range(10):
|
||||
vq_sample = self.latent_model.generate(
|
||||
1, temperature=1, x=None
|
||||
)
|
||||
logging.info(
|
||||
torch.nonzero(vq_sample[0])[:, 1]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
vq_sample = self.latent_model.generate(
|
||||
1, temperature=1, x=None
|
||||
)
|
||||
all_actions = self.policy(
|
||||
qpos, current_image, vq_sample=vq_sample
|
||||
)
|
||||
else:
|
||||
all_actions = self.policy(qpos, current_image)
|
||||
# if self.real_robot:
|
||||
# all_actions = torch.cat(
|
||||
# [
|
||||
# all_actions[:, :-BASE_DELAY, :-2],
|
||||
# all_actions[:, BASE_DELAY:, -2:],
|
||||
# ],
|
||||
# dim=2,
|
||||
# )
|
||||
if self.temporal_agg:
|
||||
all_time_actions[[t], t : t + num_queries] = all_actions
|
||||
actions_for_curr_step = all_time_actions[:, t]
|
||||
actions_populated = torch.all(
|
||||
actions_for_curr_step != 0, axis=1
|
||||
)
|
||||
actions_for_curr_step = actions_for_curr_step[
|
||||
actions_populated
|
||||
]
|
||||
k = 0.01
|
||||
exp_weights = np.exp(
|
||||
-k * np.arange(len(actions_for_curr_step))
|
||||
)
|
||||
exp_weights = exp_weights / exp_weights.sum()
|
||||
exp_weights = (
|
||||
torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
)
|
||||
raw_action = (actions_for_curr_step * exp_weights).sum(
|
||||
dim=0, keepdim=True
|
||||
)
|
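# Temporal ensembling: actions_for_curr_step stacks every earlier chunk's
# prediction for the current timestep t, and exp(-k * i) with k = 0.01
# weights the i-th prediction (oldest first), so e.g. four populated
# predictions get relative weights ~1.000, 0.990, 0.980, 0.970 before
# normalization.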
||||
else:
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
elif self.config["policy_class"] == "Diffusion":
|
||||
if t % query_frequency == 0:
|
||||
all_actions = self.policy(qpos, current_image)
|
||||
# if self.real_robot:
|
||||
# all_actions = torch.cat(
|
||||
# [
|
||||
# all_actions[:, :-BASE_DELAY, :-2],
|
||||
# all_actions[:, BASE_DELAY:, -2:],
|
||||
# ],
|
||||
# dim=2,
|
||||
# )
|
||||
raw_action = all_actions[:, t % query_frequency]
|
||||
elif self.config["policy_class"] == "CNNMLP":
|
||||
raw_action = self.policy(qpos, current_image)
|
||||
all_actions = raw_action.unsqueeze(0)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
### post-process actions
|
||||
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||
action = self.post_process(raw_action)
|
||||
|
||||
### step the environment
|
||||
if self.real_robot:
|
||||
logging.info(f" action = {action}")
|
||||
timestep = self.env.step(action)
|
||||
|
||||
rewards.append(timestep.reward)
|
||||
duration = time.time() - time_1
|
||||
sleep_time = max(0, DT - duration)
|
||||
time.sleep(sleep_time)
|
||||
if duration >= DT:
|
||||
cumulated_delay += duration - DT
|
||||
logging.warning(
|
||||
f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, culmulated delay: {culmulated_delay:.3f} s"
|
||||
)
|
||||
|
||||
logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
|
||||
plt.close()
|
||||
|
||||
if self.real_robot:
|
||||
log_id = self.get_auto_index(self.checkpoint_dir)
|
||||
np.save(
|
||||
os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
|
||||
qpos_history_raw,
|
||||
)
|
||||
plt.figure(figsize=(10, 20))
|
||||
for i in range(self.state_dim):
|
||||
plt.subplot(self.state_dim, 1, i + 1)
|
||||
plt.plot(qpos_history_raw[:, i])
|
||||
if i != self.state_dim - 1:
|
||||
plt.xticks([])
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
|
||||
plt.close()
|
||||
|
||||
rewards = np.array(rewards)
|
||||
episode_return = np.sum(rewards[rewards != None])
|
||||
episode_returns.append(episode_return)
|
||||
episode_highest_reward = np.max(rewards)
|
||||
highest_rewards.append(episode_highest_reward)
|
||||
logging.info(
|
||||
f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
|
||||
)
|
||||
|
||||
success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
|
||||
avg_return = np.mean(episode_returns)
|
||||
summary_str = (
|
||||
f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
|
||||
)
|
||||
for r in range(self.env_max_reward + 1):
|
||||
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
|
||||
more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
|
||||
summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
|
||||
|
||||
logging.info(summary_str)
|
||||
|
||||
result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
|
||||
with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
|
||||
f.write(summary_str)
|
||||
f.write(repr(episode_returns))
|
||||
f.write("\n\n")
|
||||
f.write(repr(highest_rewards))
|
||||
|
||||
return success_rate, avg_return
|
||||
|
||||
|
||||
class DeviceAloha:
|
||||
def __init__(self, aloha_config):
|
||||
"""
|
||||
初始化设备
|
||||
|
||||
Args:
|
||||
device_name (str): 设备名称
|
||||
"""
|
||||
config_left_arm = aloha_config["rm_left_arm"]
|
||||
config_right_arm = aloha_config["rm_right_arm"]
|
||||
config_head_camera = aloha_config["head_camera"]
|
||||
config_bottom_camera = aloha_config["bottom_camera"]
|
||||
config_left_camera = aloha_config["left_camera"]
|
||||
config_right_camera = aloha_config["right_camera"]
|
||||
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||
self.arm_axis = aloha_config["arm_axis"]
|
||||
self.arm_left = RmArm(config_left_arm)
|
||||
self.arm_right = RmArm(config_right_arm)
|
||||
self.camera_left = RealSenseCamera(config_head_camera, False)
|
||||
self.camera_right = RealSenseCamera(config_bottom_camera, False)
|
||||
self.camera_bottom = RealSenseCamera(config_left_camera, False)
|
||||
self.camera_top = RealSenseCamera(config_right_camera, False)
|
||||
self.camera_left.start_camera()
|
||||
self.camera_right.start_camera()
|
||||
self.camera_bottom.start_camera()
|
||||
self.camera_top.start_camera()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
关闭摄像头
|
||||
"""
|
||||
self.camera_left.close()
|
||||
self.camera_right.close()
|
||||
self.camera_bottom.close()
|
||||
self.camera_top.close()
|
||||
|
||||
def get_qps(self):
|
||||
"""
|
||||
获取关节角度
|
||||
|
||||
Returns:
|
||||
np.array: 关节角度
|
||||
"""
|
||||
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||
|
||||
def get_qvel(self):
|
||||
"""
|
||||
获取关节速度
|
||||
|
||||
Returns:
|
||||
np.array: 关节速度
|
||||
"""
|
||||
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||
|
||||
def get_effort(self):
|
||||
"""
|
||||
获取关节力
|
||||
|
||||
Returns:
|
||||
np.array: 关节力
|
||||
"""
|
||||
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||
|
||||
def get_images(self):
|
||||
"""
|
||||
获取图像
|
||||
|
||||
Returns:
|
||||
dict: 图像字典
|
||||
"""
|
||||
self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||
self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
self.left_image, _, _, _ = self.camera_left.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
self.right_image, _, _, _ = self.camera_right.read_frame(
|
||||
True, False, False, False
|
||||
)
|
||||
return {
|
||||
"cam_high": self.top_image,
|
||||
"cam_low": self.bottom_image,
|
||||
"cam_left": self.left_image,
|
||||
"cam_right": self.right_image,
|
||||
}
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qps()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting the environment")
|
||||
self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
|
||||
self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
|
||||
self.arm_left.set_gripper_position(0)
|
||||
self.arm_right.set_gripper_position(0)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
def step(self, target_angle):
|
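# target_angle layout assumed by the slicing below (2 * arm_axis + 2 values):
#   [0 : arm_axis]                      left arm joints
#   [arm_axis]                          left gripper
#   [arm_axis + 1 : 2 * arm_axis + 1]   right arm joints
#   [2 * arm_axis + 1]                  right gripper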
||||
self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
|
||||
self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
|
||||
self.arm_left.set_gripper_position(target_angle[self.arm_axis])
|
||||
self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||
# config = yaml.safe_load(f)
|
||||
# aloha_config = config["robot_env"]
|
||||
# device = DeviceAloha(aloha_config)
|
||||
# device.reset()
|
||||
# while True:
|
||||
# init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||
# time_step = time.time()
|
||||
# timestep = device.step(init_angle)
|
||||
# logging.info(f"Time: {time.time() - time_step}")
|
||||
# obs = timestep.observation
|
||||
|
||||
with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
# logging.info(f"Config: {config}")
|
||||
evaluator = RmActEvaluator(config)
|
||||
success_rate, avg_return = evaluator.evaluate()
|
||||
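For reference, a minimal sketch of the configuration dictionary RmActEvaluator reads (keys taken from __init__ and evaluate(); values are illustrative, and the real project loads them from config/config.yaml):

config = {
    "seed": 0,
    "real_robot": False,
    "robot_env": {},                                      # arm/camera settings, only used when real_robot is True
    "checkpoint_dir": "checkpoints/sim_transfer_cube",    # hypothetical path
    "checkpoint_name": "policy_best.ckpt",                # hypothetical file name
    "state_dim": 14,
    "policy_class": "ACT",
    "onscreen_render": False,
    "camera_names": ["top"],
    "episode_len": 400,
    "task_name": "sim_transfer_cube_scripted",
    "temporal_agg": True,
    "policy_config": {"num_queries": 100, "vq": False},   # plus whatever ACTPolicy itself requires
}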
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Backbone modules.
|
||||
"""
|
||||
import torch
|
||||
import torchvision
|
||||
from torch import nn
|
||||
from typing import Dict, List
|
||||
import torch.nn.functional as F
|
||||
from .position_encoding import build_position_encoding
|
||||
from torchvision.models import ResNet18_Weights
|
||||
from torchvision.models._utils import IntermediateLayerGetter
|
||||
from shadow_act.utils.misc import NestedTensor, is_main_process
|
||||
|
||||
|
||||
class FrozenBatchNorm2d(torch.nn.Module):
|
||||
"""
|
||||
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||
|
||||
Copy-paste from torchvision.misc.ops with added eps before rsqrt,
|
||||
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||
produce nans.
|
||||
"""
|
||||
|
||||
def __init__(self, n):
|
||||
super(FrozenBatchNorm2d, self).__init__()
|
||||
self.register_buffer("weight", torch.ones(n))
|
||||
self.register_buffer("bias", torch.zeros(n))
|
||||
self.register_buffer("running_mean", torch.zeros(n))
|
||||
self.register_buffer("running_var", torch.ones(n))
|
||||
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||
if num_batches_tracked_key in state_dict:
|
||||
del state_dict[num_batches_tracked_key]
|
||||
|
||||
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# move reshapes to the beginning
|
||||
# to make it fuser-friendly
|
||||
w = self.weight.reshape(1, -1, 1, 1)
|
||||
b = self.bias.reshape(1, -1, 1, 1)
|
||||
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||
eps = 1e-5
|
||||
scale = w * (rv + eps).rsqrt()
|
||||
bias = b - rm * scale
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class BackboneBase(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: nn.Module,
|
||||
train_backbone: bool,
|
||||
num_channels: int,
|
||||
return_interm_layers: bool,
|
||||
):
|
||||
super().__init__()
|
||||
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||
# parameter.requires_grad_(False)
|
||||
if return_interm_layers:
|
||||
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||
else:
|
||||
return_layers = {"layer4": "0"}
|
||||
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||
self.num_channels = num_channels
|
||||
|
||||
def forward(self, tensor):
|
||||
xs = self.body(tensor)
|
||||
return xs
|
||||
# out: Dict[str, NestedTensor] = {}
|
||||
# for name, x in xs.items():
|
||||
# m = tensor_list.mask
|
||||
# assert m is not None
|
||||
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||
# out[name] = NestedTensor(x, mask)
|
||||
# return out
|
||||
|
||||
|
||||
class Backbone(BackboneBase):
|
||||
"""ResNet backbone with frozen BatchNorm."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
train_backbone: bool,
|
||||
return_interm_layers: bool,
|
||||
dilation: bool,
|
||||
):
|
||||
backbone = getattr(torchvision.models, name)(
|
||||
replace_stride_with_dilation=[False, False, dilation],
|
||||
weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
# backbone = getattr(torchvision.models, name)(
|
||||
# replace_stride_with_dilation=[False, False, dilation],
|
||||
# pretrained=is_main_process(),
|
||||
# norm_layer=FrozenBatchNorm2d,
|
||||
# ) # pretrained # TODO do we want frozen batch_norm??
|
||||
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||
|
||||
|
||||
class Joiner(nn.Sequential):
|
||||
def __init__(self, backbone, position_embedding):
|
||||
super().__init__(backbone, position_embedding)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
xs = self[0](tensor_list)
|
||||
out: List[NestedTensor] = []
|
||||
pos = []
|
||||
for name, x in xs.items():
|
||||
out.append(x)
|
||||
# position encoding
|
||||
pos.append(self[1](x).to(x.dtype))
|
||||
|
||||
return out, pos
|
||||
|
||||
|
||||
def build_backbone(
|
||||
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||
):
|
||||
|
||||
position_embedding = build_position_encoding(
|
||||
hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
|
||||
)
|
||||
train_backbone = lr_backbone > 0
|
||||
return_interm_layers = masks
|
||||
backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
|
||||
model = Joiner(backbone, position_embedding)
|
||||
model.num_channels = backbone.num_channels
|
||||
return model
|
||||
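A minimal sketch of exercising the Backbone class on a 480x640 camera frame (the import path is assumed to mirror this module's location; pretrained weights download on first use):

import torch
from shadow_act.models.backbone import Backbone   # hypothetical import path for this module

backbone = Backbone("resnet18", train_backbone=True, return_interm_layers=False, dilation=False)
feats = backbone(torch.randn(1, 3, 480, 640))      # dict produced by IntermediateLayerGetter
print(feats["0"].shape)                            # torch.Size([1, 512, 15, 20]) for resnet18 (stride 32)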
@@ -0,0 +1,436 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR model and criterion classes.
|
||||
"""
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
import torch.nn.functional as F
|
||||
from shadow_act.models.transformer import Transformer
|
||||
from .backbone import build_backbone
|
||||
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def reparametrize(mu, logvar):
|
||||
std = logvar.div(2).exp()
|
||||
eps = Variable(std.data.new(std.size()).normal_())
|
||||
return mu + std * eps
|
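# Standard VAE reparameterization trick: with std = exp(logvar / 2), a sample
# z ~ N(mu, std^2) is produced as z = mu + std * eps with eps ~ N(0, I), which
# keeps the sampling differentiable w.r.t. mu and logvar. (torch.autograd.Variable
# is a legacy wrapper here; torch.randn_like(std) would be the modern way to
# build eps.)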
||||
|
||||
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||
for hid_j in range(d_hid)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
||||
)
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
class DETRVAE(nn.Module):
|
||||
"""This is the DETR module that performs object detection"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
transformer: torch module of the transformer architecture. See transformer.py
|
||||
state_dim: robot state dimension of the environment
|
||||
num_queries: number of action queries, i.e. the length of the predicted action chunk.
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_queries = num_queries
|
||||
self.camera_names = camera_names
|
||||
self.transformer = transformer
|
||||
self.encoder = encoder
|
||||
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
|
||||
self.state_dim, self.action_dim = state_dim, action_dim
|
||||
hidden_dim = transformer.d_model
|
||||
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||
if backbones is not None:
|
||||
self.input_proj = nn.Conv2d(
|
||||
backbones[0].num_channels, hidden_dim, kernel_size=1
|
||||
)
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
else:
|
||||
# input_dim = 14 + 7 # robot_state + env_state
|
||||
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||
self.backbones = None
|
||||
|
||||
# encoder extra parameters
|
||||
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||
self.encoder_action_proj = nn.Linear(
|
||||
action_dim, hidden_dim
|
||||
) # project action to embedding
|
||||
self.encoder_joint_proj = nn.Linear(
|
||||
action_dim, hidden_dim
|
||||
) # project qpos to embedding
|
||||
if self.vq:
|
||||
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
|
||||
else:
|
||||
self.latent_proj = nn.Linear(
|
||||
hidden_dim, self.latent_dim * 2
|
||||
) # project hidden state to latent std, var
|
||||
self.register_buffer(
|
||||
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||
) # [CLS], qpos, a_seq
|
||||
|
||||
# decoder extra parameters
|
||||
if self.vq:
|
||||
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
|
||||
else:
|
||||
self.latent_out_proj = nn.Linear(
|
||||
self.latent_dim, hidden_dim
|
||||
) # project latent sample to embedding
|
||||
self.additional_pos_embed = nn.Embedding(
|
||||
2, hidden_dim
|
||||
) # learned position embedding for proprio and latent
|
||||
|
||||
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
|
||||
bs, _ = qpos.shape
|
||||
if self.encoder is None:
|
||||
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
|
||||
qpos.device
|
||||
)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
probs = binaries = mu = logvar = None
|
||||
else:
|
||||
# cvae encoder
|
||||
is_training = actions is not None # train or val
|
||||
### Obtain latent z from action sequence
|
||||
if is_training:
|
||||
# project action sequence to embedding dim, and concat with a CLS token
|
||||
action_embed = self.encoder_action_proj(
|
||||
actions
|
||||
) # (bs, seq, hidden_dim)
|
||||
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
|
||||
bs, 1, 1
|
||||
) # (bs, 1, hidden_dim)
|
||||
encoder_input = torch.cat(
|
||||
[cls_embed, qpos_embed, action_embed], axis=1
|
||||
) # (bs, seq+1, hidden_dim)
|
||||
encoder_input = encoder_input.permute(
|
||||
1, 0, 2
|
||||
) # (seq+1, bs, hidden_dim)
|
||||
# do not mask cls token
|
||||
cls_joint_is_pad = torch.full((bs, 2), False).to(
|
||||
qpos.device
|
||||
) # False: not a padding
|
||||
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||
# obtain position embedding
|
||||
pos_embed = self.pos_table.clone().detach()
|
||||
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||
# query model
|
||||
encoder_output = self.encoder(
|
||||
encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
|
||||
)
|
||||
encoder_output = encoder_output[0] # take cls output only
|
||||
latent_info = self.latent_proj(encoder_output)
|
||||
|
||||
if self.vq:
|
||||
logits = latent_info.reshape(
|
||||
[*latent_info.shape[:-1], self.vq_class, self.vq_dim]
|
||||
)
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
binaries = (
|
||||
F.one_hot(
|
||||
torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
|
||||
-1
|
||||
),
|
||||
self.vq_dim,
|
||||
)
|
||||
.view(-1, self.vq_class, self.vq_dim)
|
||||
.float()
|
||||
)
|
||||
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
|
||||
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
|
||||
straight_through = binaries_flat - probs_flat.detach() + probs_flat
latent_input = self.latent_out_proj(straight_through)
|
||||
mu = logvar = None
|
||||
else:
|
||||
probs = binaries = None
|
||||
mu = latent_info[:, : self.latent_dim]
|
||||
logvar = latent_info[:, self.latent_dim :]
|
||||
latent_sample = reparametrize(mu, logvar)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
else:
|
||||
mu = logvar = binaries = probs = None
|
||||
if self.vq:
|
||||
latent_input = self.latent_out_proj(
|
||||
vq_sample.view(-1, self.vq_class * self.vq_dim)
|
||||
)
|
||||
else:
|
||||
latent_sample = torch.zeros(
|
||||
[bs, self.latent_dim], dtype=torch.float32
|
||||
).to(qpos.device)
|
||||
latent_input = self.latent_out_proj(latent_sample)
|
||||
|
||||
return latent_input, probs, binaries, mu, logvar
|
||||
|
||||
def forward(
|
||||
self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
|
||||
):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
|
||||
latent_input, probs, binaries, mu, logvar = self.encode(
|
||||
qpos, actions, is_pad, vq_sample
|
||||
)
|
||||
|
||||
# cvae decoder
|
||||
if self.backbones is not None:
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
all_cam_pos = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
# TODO: fix this error
|
||||
features, pos = self.backbones[0](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0]
|
||||
all_cam_features.append(self.input_proj(features))
|
||||
all_cam_pos.append(pos)
|
||||
# proprioception features
|
||||
proprio_input = self.input_proj_robot_state(qpos)
|
||||
# fold camera dimension into width dimension
|
||||
src = torch.cat(all_cam_features, axis=3)
|
||||
pos = torch.cat(all_cam_pos, axis=3)
|
||||
hs = self.transformer(
|
||||
src,
|
||||
None,
|
||||
self.query_embed.weight,
|
||||
pos,
|
||||
latent_input,
|
||||
proprio_input,
|
||||
self.additional_pos_embed.weight,
|
||||
)[0]
|
||||
else:
|
||||
qpos = self.input_proj_robot_state(qpos)
|
||||
env_state = self.input_proj_env_state(env_state)
|
||||
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||
hs = self.transformer(
|
||||
transformer_input, None, self.query_embed.weight, self.pos.weight
|
||||
)[0]
|
||||
a_hat = self.action_head(hs)
|
||||
is_pad_hat = self.is_pad_head(hs)
|
||||
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
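# Shape summary for DETRVAE.forward (added for clarity, not in the original):
#   qpos:        (B, state_dim)
#   image:       (B, num_cam, 3, H, W)
#   a_hat:       (B, num_queries, action_dim)
#   is_pad_hat:  (B, num_queries, 1)
#   mu, logvar:  (B, latent_dim) during training with a CVAE encoder, else None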
|
||||
|
||||
|
||||
class CNNMLP(nn.Module):
|
||||
def __init__(self, backbones, state_dim, camera_names):
|
||||
"""Initializes the model.
|
||||
Parameters:
|
||||
backbones: torch module of the backbone to be used. See backbone.py
|
||||
state_dim: robot state dimension of the environment
camera_names: list of camera names used as image inputs
|
||||
"""
|
||||
super().__init__()
|
||||
self.camera_names = camera_names
|
||||
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||
if backbones is not None:
|
||||
self.backbones = nn.ModuleList(backbones)
|
||||
backbone_down_projs = []
|
||||
for backbone in backbones:
|
||||
down_proj = nn.Sequential(
|
||||
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||
nn.Conv2d(128, 64, kernel_size=5),
|
||||
nn.Conv2d(64, 32, kernel_size=5),
|
||||
)
|
||||
backbone_down_projs.append(down_proj)
|
||||
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||
|
||||
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||
self.mlp = mlp(
|
||||
input_dim=mlp_in_dim,
|
||||
hidden_dim=1024,
|
||||
output_dim=state_dim,  # assumption: the predicted action has the same dimension as the robot state here
|
||||
hidden_depth=2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, qpos, image, env_state, actions=None):
|
||||
"""
|
||||
qpos: batch, qpos_dim
|
||||
image: batch, num_cam, channel, height, width
|
||||
env_state: None
|
||||
actions: batch, seq, action_dim
|
||||
"""
|
||||
is_training = actions is not None # train or val
|
||||
bs, _ = qpos.shape
|
||||
# Image observation features and position embeddings
|
||||
all_cam_features = []
|
||||
for cam_id, cam_name in enumerate(self.camera_names):
|
||||
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||
features = features[0] # take the last layer feature
|
||||
pos = pos[0] # not used
|
||||
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||
# flatten everything
|
||||
flattened_features = []
|
||||
for cam_feature in all_cam_features:
|
||||
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||
a_hat = self.mlp(features)
|
||||
return a_hat
|
||||
|
||||
|
||||
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||
if hidden_depth == 0:
|
||||
mods = [nn.Linear(input_dim, output_dim)]
|
||||
else:
|
||||
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
for i in range(hidden_depth - 1):
|
||||
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||
trunk = nn.Sequential(*mods)
|
||||
return trunk
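# Illustrative usage sketch (not in the original): with hidden_depth=2 the
# helper above builds Linear -> ReLU -> Linear -> ReLU -> Linear.
def _example_mlp():
    net = mlp(input_dim=782, hidden_dim=1024, output_dim=14, hidden_depth=2)
    out = net(torch.randn(8, 782))
    assert out.shape == (8, 14)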
|
||||
|
||||
|
||||
def build_encoder(
|
||||
hidden_dim, # 256
|
||||
dropout, # 0.1
|
||||
nheads, # 8
|
||||
dim_feedforward,
|
||||
num_encoder_layers, # 4 # TODO shared with VAE decoder
|
||||
normalize_before, # False
|
||||
):
|
||||
activation = "relu"
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
|
||||
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build_vae(
|
||||
hidden_dim,
|
||||
state_dim,
|
||||
position_embedding_type,
|
||||
lr_backbone,
|
||||
masks,
|
||||
backbone,
|
||||
dilation,
|
||||
dropout,
|
||||
nheads,
|
||||
dim_feedforward,
|
||||
enc_layers,
|
||||
dec_layers,
|
||||
pre_norm,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
no_encoder,
|
||||
):
|
||||
# TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
backbone = build_backbone(
|
||||
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||
)
|
||||
backbones.append(backbone)
|
||||
|
||||
transformer = build_transformer(
|
||||
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||
)
|
||||
|
||||
if no_encoder:
|
||||
encoder = None
|
||||
else:
|
||||
encoder = build_encoder(
|
||||
hidden_dim,
|
||||
dropout,
|
||||
nheads,
|
||||
dim_feedforward,
|
||||
enc_layers,
|
||||
pre_norm,
|
||||
)
|
||||
|
||||
model = DETRVAE(
|
||||
backbones,
|
||||
transformer,
|
||||
encoder,
|
||||
state_dim,
|
||||
num_queries,
|
||||
camera_names,
|
||||
vq,
|
||||
vq_class,
|
||||
vq_dim,
|
||||
action_dim,
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||
|
||||
return model
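# Hedged usage sketch (not in the original; the values below are assumptions in
# the spirit of common ACT settings, not defaults taken from this repository):
def _example_build_vae():
    return build_vae(
        hidden_dim=512,
        state_dim=14,
        position_embedding_type="sine",
        lr_backbone=1e-5,
        masks=False,
        backbone="resnet18",
        dilation=False,
        dropout=0.1,
        nheads=8,
        dim_feedforward=3200,
        enc_layers=4,
        dec_layers=7,
        pre_norm=False,
        num_queries=100,
        camera_names=["top"],
        vq=False,
        vq_class=None,
        vq_dim=None,
        action_dim=16,
        no_encoder=False,
    )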
|
||||
|
||||
# TODO
|
||||
def build_cnnmlp(args):
|
||||
state_dim = 14 # TODO hardcode
|
||||
|
||||
# From state
|
||||
# backbone = None # from state for now, no need for conv nets
|
||||
# From image
|
||||
backbones = []
|
||||
for _ in args["camera_names"]:
|
||||
backbone = build_backbone(args["hidden_dim"], args["position_embedding"], args["lr_backbone"], args["masks"], args["backbone"], args["dilation"])  # build_backbone here takes individual fields; these config keys are assumed to match the ACT config
|
||||
backbones.append(backbone)
|
||||
|
||||
model = CNNMLP(
|
||||
backbones,
|
||||
state_dim=state_dim,
|
||||
camera_names=args["camera_names"],
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||
|
||||
return model
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torch
|
||||
|
||||
DROPOUT_RATE = 0.1  # dropout rate


# A causal transformer block
class Causal_Transformer_Block(nn.Module):
    def __init__(self, seq_len, latent_dim, num_head) -> None:
        """
        Initialize the causal transformer block.

        Args:
            seq_len (int): sequence length
            latent_dim (int): latent dimension
            num_head (int): number of attention heads
        """
        super().__init__()
        self.num_head = num_head
        self.latent_dim = latent_dim
        self.ln_1 = nn.LayerNorm(latent_dim)  # layer norm
        self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)  # multi-head attention
        self.ln_2 = nn.LayerNorm(latent_dim)  # layer norm
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, 4 * latent_dim),  # linear layer
            nn.GELU(),  # GELU activation
            nn.Linear(4 * latent_dim, latent_dim),  # linear layer
            nn.Dropout(DROPOUT_RATE),  # dropout
        )

        # self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())  # register the attention mask

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output tensor
        """
        # upper-triangular mask so positions cannot attend to future time steps
        attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
        x = self.ln_1(x)  # layer norm
        x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]  # add attention output
        x = self.ln_2(x)  # layer norm
        x = x + self.mlp(x)  # add MLP output

        return x
|
||||
|
||||
# Model the latent code sequence with self-attention instead of an RNN
class Latent_Model_Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
        """
        Initialize the latent model transformer.

        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            seq_len (int): sequence length
            latent_dim (int, optional): latent dimension, defaults to 256
            num_head (int, optional): number of attention heads, defaults to 8
            num_layer (int, optional): number of transformer layers, defaults to 3
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.num_head = num_head
        self.num_layer = num_layer
        self.input_layer = nn.Linear(input_dim, latent_dim)  # input projection
        self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)  # positional embedding
        self.attention_blocks = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),  # dropout
            *[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],  # stacked causal transformer blocks
            nn.LayerNorm(latent_dim)  # layer norm
        )
        self.output_layer = nn.Linear(latent_dim, output_dim)  # output projection

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): input tensor

        Returns:
            torch.Tensor: output logits
        """
        x = self.input_layer(x)  # input projection
        x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))  # add positional embedding
        x = self.attention_blocks(x)  # pass through the attention blocks
        logits = self.output_layer(x)  # output projection

        return logits

    @torch.no_grad()
    def generate(self, n, temperature=0.1, x=None):
        """
        Generate sequences.

        Args:
            n (int): number of sequences to generate
            temperature (float, optional): sampling temperature, defaults to 0.1
            x (torch.Tensor, optional): initial input tensor, defaults to None

        Returns:
            torch.Tensor: generated sequences
        """
        if x is None:
            x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)  # initial input
        for i in range(self.seq_len):
            logits = self.forward(x)[:, -1]  # output at the last time step
            probs = torch.softmax(logits / temperature, dim=-1)  # sampling distribution
            samples = torch.multinomial(probs, num_samples=1)[..., 0]  # sample from the distribution
            samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()  # one-hot encode
            x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)  # append the new sample to the input

        return x[:, 1:, :]  # return the generated sequence (dropping the initial zero input)
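# Illustrative usage sketch (not in the original; sizes are arbitrary): the
# model maps one-hot latent codes of size input_dim to logits over output_dim
# classes, and generate() samples a sequence autoregressively.
def _example_latent_model():
    model = Latent_Model_Transformer(input_dim=32, output_dim=32, seq_len=8)
    logits = model(torch.zeros(4, 8, 32))  # (4, 8, 32)
    samples = model.generate(n=4)          # (4, 8, 32), one-hot rows
    return logits, samples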
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Various positional encodings for the transformer.
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from shadow_act.utils.misc import NestedTensor
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, tensor):
|
||||
x = tensor
|
||||
# mask = tensor_list.mask
|
||||
# assert mask is not None
|
||||
# not_mask = ~mask
|
||||
|
||||
not_mask = torch.ones_like(x[0, [0]])
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Module):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.uniform_(self.row_embed.weight)
|
||||
nn.init.uniform_(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
h, w = x.shape[-2:]
|
||||
i = torch.arange(w, device=x.device)
|
||||
j = torch.arange(h, device=x.device)
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
pos = torch.cat([
|
||||
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(hidden_dim, position_embedding_type):
|
||||
N_steps = hidden_dim // 2
|
||||
if position_embedding_type in ('v2', 'sine'):
|
||||
# TODO find a better way of exposing other arguments
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif position_embedding_type in ('v3', 'learned'):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {position_embedding_type}")
|
||||
|
||||
return position_embedding
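# Illustrative usage sketch (not in the original): with hidden_dim=512 the sine
# variant returns a (1, 512, H, W) positional map for a (B, C, H, W) feature
# map; the batch dimension is repeated later by the caller.
def _example_position_encoding():
    pos_enc = build_position_encoding(hidden_dim=512, position_embedding_type="sine")
    pos = pos_enc(torch.randn(2, 512, 15, 20))
    assert pos.shape[1:] == (512, 15, 20)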
|
||||
@@ -0,0 +1,424 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
DETR Transformer class.
|
||||
|
||||
Copy-paste from torch.nn.Transformer with modifications:
|
||||
* positional encodings are passed in MHattention
|
||||
* extra LN at the end of encoder is removed
|
||||
* decoder returns a stack of activations from all decoding layers
|
||||
"""
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, Tensor
|
||||
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
nhead=8,
|
||||
num_encoder_layers=6,
|
||||
num_decoder_layers=6,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
return_intermediate_dec=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||
self.encoder = TransformerEncoder(
|
||||
encoder_layer, num_encoder_layers, encoder_norm
|
||||
)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||
)
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
self.decoder = TransformerDecoder(
|
||||
decoder_layer,
|
||||
num_decoder_layers,
|
||||
decoder_norm,
|
||||
return_intermediate=return_intermediate_dec,
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask,
|
||||
query_embed,
|
||||
pos_embed,
|
||||
latent_input=None,
|
||||
proprio_input=None,
|
||||
additional_pos_embed=None,
|
||||
):
|
||||
# TODO flatten only when input has H and W
|
||||
if len(src.shape) == 4: # has H and W
|
||||
# flatten NxCxHxW to HWxNxC
|
||||
bs, c, h, w = src.shape
|
||||
src = src.flatten(2).permute(2, 0, 1)
|
||||
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
# mask = mask.flatten(1)
|
||||
|
||||
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
|
||||
1, bs, 1
|
||||
) # seq, bs, dim
|
||||
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||
|
||||
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||
src = torch.cat([addition_input, src], axis=0)
|
||||
else:
|
||||
assert len(src.shape) == 3
|
||||
# flatten NxHWxC to HWxNxC
|
||||
bs, hw, c = src.shape
|
||||
src = src.permute(1, 0, 2)
|
||||
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||
|
||||
tgt = torch.zeros_like(query_embed)
|
||||
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||
hs = self.decoder(
|
||||
tgt,
|
||||
memory,
|
||||
memory_key_padding_mask=mask,
|
||||
pos=pos_embed,
|
||||
query_pos=query_embed,
|
||||
)
|
||||
hs = hs.transpose(1, 2)
|
||||
return hs
|
||||
|
||||
|
||||
class TransformerEncoder(nn.Module):
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(encoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = src
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
pos=pos,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
|
||||
|
||||
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||
super().__init__()
|
||||
self.layers = _get_clones(decoder_layer, num_layers)
|
||||
self.num_layers = num_layers
|
||||
self.norm = norm
|
||||
self.return_intermediate = return_intermediate
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
output = tgt
|
||||
|
||||
intermediate = []
|
||||
|
||||
for layer in self.layers:
|
||||
output = layer(
|
||||
output,
|
||||
memory,
|
||||
tgt_mask=tgt_mask,
|
||||
memory_mask=memory_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
pos=pos,
|
||||
query_pos=query_pos,
|
||||
)
|
||||
if self.return_intermediate:
|
||||
intermediate.append(self.norm(output))
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
if self.return_intermediate:
|
||||
intermediate.pop()
|
||||
intermediate.append(output)
|
||||
|
||||
if self.return_intermediate:
|
||||
return torch.stack(intermediate)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(src, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src = self.norm1(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = src + self.dropout2(src2)
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
src2 = self.norm1(src)
|
||||
q = k = self.with_pos_embed(src2, pos)
|
||||
src2 = self.self_attn(
|
||||
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||
)[0]
|
||||
src = src + self.dropout1(src2)
|
||||
src2 = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||
src = src + self.dropout2(src2)
|
||||
return src
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(
|
||||
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||
)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, query_pos),
|
||||
key=self.with_pos_embed(memory, pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
memory_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
memory_key_padding_mask: Optional[Tensor] = None,
|
||||
pos: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
tgt_key_padding_mask,
|
||||
memory_key_padding_mask,
|
||||
pos,
|
||||
query_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def build_transformer(
|
||||
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||
):
|
||||
return Transformer(
|
||||
d_model=hidden_dim,
|
||||
dropout=dropout,
|
||||
nhead=nheads,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=enc_layers,
|
||||
num_decoder_layers=dec_layers,
|
||||
normalize_before=pre_norm,
|
||||
return_intermediate_dec=True,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,522 @@
|
||||
import torch
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from shadow_act.models.detr_vae import build_vae, build_cnnmlp
|
||||
|
||||
# NOTE: these imports are required by DiffusionPolicy below
from diffusers.training_utils import EMAModel
from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D

# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
|
||||
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# TODO: refactor the DiffusionPolicy class
|
||||
class DiffusionPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
"""
|
||||
初始化DiffusionPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.camera_names = args_override["camera_names"]
|
||||
self.observation_horizon = args_override["observation_horizon"]
|
||||
self.action_horizon = args_override["action_horizon"]
|
||||
self.prediction_horizon = args_override["prediction_horizon"]
|
||||
self.num_inference_timesteps = args_override["num_inference_timesteps"]
|
||||
self.ema_power = args_override["ema_power"]
|
||||
self.lr = args_override["lr"]
|
||||
self.weight_decay = 0
|
||||
|
||||
self.num_kp = 32
|
||||
self.feature_dimension = 64
|
||||
self.ac_dim = args_override["action_dim"]
|
||||
self.obs_dim = self.feature_dimension * len(self.camera_names) + 14
|
||||
|
||||
backbones = []
|
||||
pools = []
|
||||
linears = []
|
||||
for _ in self.camera_names:
|
||||
backbones.append(
|
||||
ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
|
||||
)
|
||||
pools.append(
|
||||
SpatialSoftmax(
|
||||
input_shape=[512, 15, 20],
|
||||
num_kp=self.num_kp,
|
||||
temperature=1.0,
|
||||
learnable_temperature=False,
|
||||
noise_std=0.0,
|
||||
)
|
||||
)
|
||||
linears.append(
|
||||
torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
|
||||
)
|
||||
backbones = nn.ModuleList(backbones)
|
||||
pools = nn.ModuleList(pools)
|
||||
linears = nn.ModuleList(linears)
|
||||
|
||||
backbones = replace_bn_with_gn(backbones)
|
||||
|
||||
noise_pred_net = ConditionalUnet1D(
|
||||
input_dim=self.ac_dim,
|
||||
global_cond_dim=self.obs_dim * self.observation_horizon,
|
||||
)
|
||||
|
||||
nets = nn.ModuleDict(
|
||||
{
|
||||
"policy": nn.ModuleDict(
|
||||
{
|
||||
"backbones": backbones,
|
||||
"pools": pools,
|
||||
"linears": linears,
|
||||
"noise_pred_net": noise_pred_net,
|
||||
}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
nets = nets.float().cuda()
|
||||
ENABLE_EMA = True
|
||||
if ENABLE_EMA:
|
||||
ema = EMAModel(model=nets, power=self.ema_power)
|
||||
else:
|
||||
ema = None
|
||||
self.nets = nets
|
||||
self.ema = ema
|
||||
|
||||
# Set up the noise scheduler
|
||||
self.noise_scheduler = DDIMScheduler(
|
||||
num_train_timesteps=50,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
steps_offset=0,
|
||||
prediction_type="epsilon",
|
||||
)
|
||||
|
||||
n_parameters = sum(p.numel() for p in self.parameters())
|
||||
logger.info("number of parameters: %.2fM", n_parameters / 1e6)
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
|
||||
)
|
||||
return optimizer
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
B = qpos.shape[0]
|
||||
if actions is not None:  # training
|
||||
nets = self.nets
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# add noise to the actions
noise = torch.randn(actions.shape, device=obs_cond.device)

# sample one diffusion timestep for each data point
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(B,),
|
||||
device=obs_cond.device,
|
||||
).long()
|
||||
|
||||
# add noise to the clean actions according to the per-timestep noise magnitude
noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)

# predict the noise residual
|
||||
noise_pred = nets["policy"]["noise_pred_net"](
|
||||
noisy_actions, timesteps, global_cond=obs_cond
|
||||
)
|
||||
|
||||
# L2 loss
|
||||
all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
|
||||
loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()
|
||||
|
||||
loss_dict = {}
|
||||
loss_dict["l2_loss"] = loss
|
||||
loss_dict["loss"] = loss
|
||||
|
||||
if self.training and self.ema is not None:
|
||||
self.ema.step(nets)
|
||||
return loss_dict
|
||||
else:  # inference
|
||||
To = self.observation_horizon
|
||||
Ta = self.action_horizon
|
||||
Tp = self.prediction_horizon
|
||||
action_dim = self.ac_dim
|
||||
|
||||
nets = self.nets
|
||||
if self.ema is not None:
|
||||
nets = self.ema.averaged_model
|
||||
|
||||
all_features = []
|
||||
for cam_id in range(len(self.camera_names)):
|
||||
cam_image = image[:, cam_id]
|
||||
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||
all_features.append(out_features)
|
||||
|
||||
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||
|
||||
# initialize the action from Gaussian noise
noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
naction = noisy_action

# initialize the scheduler
|
||||
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||
|
||||
for k in self.noise_scheduler.timesteps:
|
||||
# predict the noise
noise_pred = nets["policy"]["noise_pred_net"](
sample=naction, timestep=k, global_cond=obs_cond
)

# reverse diffusion step (remove noise)
|
||||
naction = self.noise_scheduler.step(
|
||||
model_output=noise_pred, timestep=k, sample=naction
|
||||
).prev_sample
|
||||
|
||||
return naction
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
序列化模型
|
||||
|
||||
Returns:
|
||||
dict: 模型状态字典
|
||||
"""
|
||||
return {
|
||||
"nets": self.nets.state_dict(),
|
||||
"ema": (
|
||||
self.ema.averaged_model.state_dict() if self.ema is not None else None
|
||||
),
|
||||
}
|
||||
|
||||
def deserialize(self, model_dict):
|
||||
"""
|
||||
反序列化模型
|
||||
|
||||
Args:
|
||||
model_dict (dict): 模型状态字典
|
||||
|
||||
Returns:
|
||||
status: 加载状态
|
||||
"""
|
||||
status = self.nets.load_state_dict(model_dict["nets"])
|
||||
logger.info("Loaded model")
|
||||
if model_dict.get("ema", None) is not None:
|
||||
logger.info("Loaded EMA")
|
||||
status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
|
||||
status = [status, status_ema]
|
||||
return status
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module):
|
||||
def __init__(self, act_config):
|
||||
"""
|
||||
初始化ACTPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
lr_backbone = act_config["lr_backbone"]
|
||||
vq = act_config["vq"]
|
||||
lr = act_config["lr"]
|
||||
weight_decay = act_config["weight_decay"]
|
||||
|
||||
model = build_vae(
|
||||
act_config["hidden_dim"],
|
||||
act_config["state_dim"],
|
||||
act_config["position_embedding"],
|
||||
lr_backbone,
|
||||
act_config["masks"],
|
||||
act_config["backbone"],
|
||||
act_config["dilation"],
|
||||
act_config["dropout"],
|
||||
act_config["nheads"],
|
||||
act_config["dim_feedforward"],
|
||||
act_config["enc_layers"],
|
||||
act_config["dec_layers"],
|
||||
act_config["pre_norm"],
|
||||
act_config["num_queries"],
|
||||
act_config["camera_names"],
|
||||
vq,
|
||||
act_config["vq_class"],
|
||||
act_config["vq_dim"],
|
||||
act_config["action_dim"],
|
||||
act_config["no_encoder"],
|
||||
)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad
|
||||
],
|
||||
"lr": lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
param_dicts, lr=lr, weight_decay=weight_decay
|
||||
)
|
||||
self.model = model  # CVAE decoder
|
||||
self.kl_weight = act_config["kl_weight"]
|
||||
self.vq = vq
|
||||
logger.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 角度张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
vq_sample (torch.Tensor, optional): VQ样本
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
image = normalize(image)
|
||||
if actions is not None:  # training
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
loss_dict = dict()
|
||||
a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
|
||||
qpos, image, env_state, actions, is_pad
|
||||
)
|
||||
if self.vq or self.model.encoder is None:
|
||||
total_kld = [torch.tensor(0.0)]
|
||||
else:
|
||||
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||
if self.vq:
|
||||
loss_dict["vq_discrepancy"] = F.l1_loss(
|
||||
probs, binaries, reduction="mean"
|
||||
)
|
||||
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||
loss_dict["l1"] = l1
|
||||
loss_dict["kl"] = total_kld[0]
|
||||
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||
return loss_dict
|
||||
else:  # inference
|
||||
a_hat, _, (_, _), _, _ = self.model(
|
||||
qpos, image, env_state, vq_sample=vq_sample
|
||||
) # no action, sample from prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
return self.optimizer
|
||||
|
||||
@torch.no_grad()
|
||||
def vq_encode(self, qpos, actions, is_pad):
|
||||
"""
|
||||
VQ编码
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
actions (torch.Tensor): 动作张量
|
||||
is_pad (torch.Tensor): 填充张量
|
||||
|
||||
Returns:
|
||||
torch.Tensor: 二进制编码
|
||||
"""
|
||||
actions = actions[:, : self.model.num_queries]
|
||||
is_pad = is_pad[:, : self.model.num_queries]
|
||||
|
||||
_, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)
|
||||
|
||||
return binaries
|
||||
|
||||
def serialize(self):
|
||||
"""
|
||||
序列化模型
|
||||
|
||||
Returns:
|
||||
dict: 模型状态字典
|
||||
"""
|
||||
return self.state_dict()
|
||||
|
||||
def deserialize(self, model_dict):
|
||||
"""
|
||||
反序列化模型
|
||||
|
||||
Args:
|
||||
model_dict (dict): 模型状态字典
|
||||
|
||||
Returns:
|
||||
status: 加载状态
|
||||
"""
|
||||
return self.load_state_dict(model_dict)
|
||||
|
||||
|
||||
class CNNMLPPolicy(nn.Module):
|
||||
def __init__(self, args_override):
|
||||
"""
|
||||
初始化CNNMLPPolicy类
|
||||
|
||||
Args:
|
||||
args_override (dict): 参数覆盖字典
|
||||
"""
|
||||
super().__init__()
|
||||
# parser = argparse.ArgumentParser(
|
||||
# "DETR training and evaluation script", parents=[get_args_parser()]
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
# for k, v in args_override.items():
|
||||
# setattr(args, k, v)
|
||||
|
||||
model = build_cnnmlp(args_override)
|
||||
model.cuda()
|
||||
|
||||
param_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" not in n and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p
|
||||
for n, p in model.named_parameters()
|
||||
if "backbone" in n and p.requires_grad
|
||||
],
|
||||
"lr": args_override.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.model = model  # decoder
self.optimizer = torch.optim.AdamW(
param_dicts, lr=args_override["lr"], weight_decay=args_override["weight_decay"]
|
||||
)
|
||||
|
||||
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||
"""
|
||||
前向传播函数
|
||||
|
||||
Args:
|
||||
qpos (torch.Tensor): 位置张量
|
||||
image (torch.Tensor): 图像张量
|
||||
actions (torch.Tensor, optional): 动作张量
|
||||
is_pad (torch.Tensor, optional): 填充张量
|
||||
|
||||
Returns:
|
||||
dict: 损失字典(训练时)
|
||||
torch.Tensor: 动作张量(推理时)
|
||||
"""
|
||||
env_state = None
|
||||
normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
image = normalize(image)
|
||||
if actions is not None:  # training
|
||||
actions = actions[:, 0]
|
||||
a_hat = self.model(qpos, image, env_state, actions)
|
||||
mse = F.mse_loss(actions, a_hat)
|
||||
loss_dict = dict()
|
||||
loss_dict["mse"] = mse
|
||||
loss_dict["loss"] = loss_dict["mse"]
|
||||
return loss_dict
|
||||
else:  # inference
a_hat = self.model(qpos, image, env_state)  # no actions, sample from the prior
|
||||
return a_hat
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
配置优化器
|
||||
|
||||
Returns:
|
||||
optimizer: 配置的优化器
|
||||
"""
|
||||
return self.optimizer
|
||||
|
||||
|
||||
def kl_divergence(mu, logvar):
|
||||
"""
|
||||
计算KL散度
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): 均值张量
|
||||
logvar (torch.Tensor): 对数方差张量
|
||||
|
||||
Returns:
|
||||
tuple: 总KL散度,维度KL散度,均值KL散度
|
||||
"""
|
||||
batch_size = mu.size(0)
|
||||
assert batch_size != 0
|
||||
if mu.data.ndimension() == 4:
|
||||
mu = mu.view(mu.size(0), mu.size(1))
|
||||
if logvar.data.ndimension() == 4:
|
||||
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||
|
||||
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||
total_kld = klds.sum(1).mean(0, True)
|
||||
dimension_wise_kld = klds.mean(0)
|
||||
mean_kld = klds.mean(1).mean(0, True)
|
||||
|
||||
return total_kld, dimension_wise_kld, mean_kld
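# Cross-check sketch (not in the original): the closed form above,
# KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + logvar - mu^2 - exp(logvar)),
# should match torch.distributions.
def _example_kl_divergence():
    from torch.distributions import Normal, kl_divergence as torch_kl

    mu, logvar = torch.randn(4, 32), torch.randn(4, 32)
    total_kld, _, _ = kl_divergence(mu, logvar)
    reference = torch_kl(Normal(mu, logvar.div(2).exp()), Normal(0.0, 1.0)).sum(1).mean()
    assert torch.allclose(total_kld.squeeze(), reference, atol=1e-5)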
|
||||
@@ -0,0 +1,245 @@
|
||||
import os
|
||||
import yaml
|
||||
import pickle
|
||||
import torch
|
||||
# import wandb
|
||||
import logging
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from copy import deepcopy
|
||||
from itertools import repeat
|
||||
from shadow_act.utils.utils import (
|
||||
set_seed,
|
||||
load_data,
|
||||
compute_dict_mean,
|
||||
)
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.eval.rm_act_eval import RmActEvaluator
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
class RmActTrainer:
|
||||
def __init__(self, config):
|
||||
"""
|
||||
初始化训练器,设置随机种子,加载数据,保存数据统计信息。
|
||||
"""
|
||||
self._config = config
|
||||
self._num_steps = config["num_steps"]
|
||||
self._ckpt_dir = config["checkpoint_dir"]
|
||||
self._state_dim = config["state_dim"]
|
||||
self._real_robot = config["real_robot"]
|
||||
self._policy_class = config["policy_class"]
|
||||
self._onscreen_render = config["onscreen_render"]
|
||||
self._policy_config = config["policy_config"]
|
||||
self._camera_names = config["camera_names"]
|
||||
self._max_timesteps = config["episode_len"]
|
||||
self._task_name = config["task_name"]
|
||||
self._temporal_agg = config["temporal_agg"]
|
||||
self._onscreen_cam = "angle"
|
||||
self._vq = config["policy_config"]["vq"]
|
||||
self._batch_size = config["batch_size"]
|
||||
|
||||
self._seed = config["seed"]
|
||||
self._eval_every = config["eval_every"]
|
||||
self._validate_every = config["validate_every"]
|
||||
self._save_every = config["save_every"]
|
||||
self._load_pretrain = config["load_pretrain"]
|
||||
self._resume_ckpt_path = config["resume_ckpt_path"]
|
||||
|
||||
if config["name_filter"] is None:
|
||||
name_filter = lambda n : True
|
||||
else:
|
||||
name_filter = config["name_filter"]
|
||||
|
||||
self._eval = RmActEvaluator(config, True, 50)
|
||||
# Load the data
|
||||
self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
|
||||
config["dataset_dir"],
|
||||
name_filter,
|
||||
self._camera_names,
|
||||
self._batch_size,
|
||||
self._batch_size,
|
||||
config["chunk_size"],
|
||||
config["skip_mirrored_data"],
|
||||
self._load_pretrain,
|
||||
self._policy_class,
|
||||
config["stats_dir"],
|
||||
config["sample_weights"],
|
||||
config["train_ratio"],
|
||||
)
|
||||
# Save dataset statistics
|
||||
stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
|
||||
with open(stats_path, "wb") as f:
|
||||
pickle.dump(self._stats, f)
|
||||
expr_name = self._ckpt_dir.split("/")[-1]
|
||||
|
||||
# wandb.init(
|
||||
# project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
|
||||
# )
|
||||
|
||||
|
||||
def _make_policy(self):
|
||||
"""
|
||||
根据策略类和配置创建策略对象
|
||||
"""
|
||||
if self._policy_class == "ACT":
|
||||
return ACTPolicy(self._policy_config)
|
||||
elif self._policy_class == "CNNMLP":
|
||||
return CNNMLPPolicy(self._policy_config)
|
||||
elif self._policy_class == "Diffusion":
|
||||
return DiffusionPolicy(self._policy_config)
|
||||
else:
|
||||
raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")
|
||||
|
||||
def _make_optimizer(self):
|
||||
"""
|
||||
根据策略类创建优化器
|
||||
"""
|
||||
if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
|
||||
return self._policy.configure_optimizers()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _forward_pass(self, data):
|
||||
"""
|
||||
前向传播,计算损失
|
||||
"""
|
||||
image_data, qpos_data, action_data, is_pad = data
|
||||
try:
|
||||
image_data, qpos_data, action_data, is_pad = (
|
||||
image_data.cuda(),
|
||||
qpos_data.cuda(),
|
||||
action_data.cuda(),
|
||||
is_pad.cuda(),
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logging.error(f"CUDA error: {e}")
|
||||
raise
|
||||
return self._policy(qpos_data, image_data, action_data, is_pad)
|
||||
|
||||
def _repeater(self):
|
||||
"""
|
||||
数据加载器的重复器,生成数据
|
||||
"""
|
||||
epoch = 0
|
||||
for loader in repeat(self._train_dataloader):
|
||||
for data in loader:
|
||||
yield data
|
||||
logging.info(f"Epoch {epoch} done")
|
||||
epoch += 1
|
||||
|
||||
def train(self):
|
||||
"""
|
||||
训练模型,保存最佳模型
|
||||
"""
|
||||
set_seed(self._seed)
|
||||
self._policy = self._make_policy()
|
||||
min_val_loss = np.inf
|
||||
best_ckpt_info = None
|
||||
|
||||
# Load a pretrained model
|
||||
if self._load_pretrain:
|
||||
try:
|
||||
loading_status = self._policy.deserialize(
|
||||
torch.load(
|
||||
os.path.join(
|
||||
"/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
|
||||
"policy_step_50000_seed_0.ckpt",
|
||||
)
|
||||
)
|
||||
)
|
||||
logging.info(f"loaded! {loading_status}")
|
||||
except FileNotFoundError as e:
|
||||
logging.error(f"Pretrain model not found: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading pretrain model: {e}")
|
||||
|
||||
# Resume from a checkpoint
|
||||
if self._resume_ckpt_path is not None:
|
||||
try:
|
||||
loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
|
||||
logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
|
||||
except FileNotFoundError as e:
|
||||
logging.error(f"Checkpoint not found: {e}")
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading checkpoint: {e}")
|
||||
|
||||
self._policy.cuda()
|
||||
|
||||
self._optimizer = self._make_optimizer()
|
||||
train_dataloader = self._repeater()  # infinite iterator over training batches
|
||||
|
||||
for step in tqdm(range(self._num_steps + 1)):
|
||||
# Validate periodically; the training step below still runs on every iteration
if step % self._validate_every == 0:
    logging.info("validating")
|
||||
with torch.inference_mode():
|
||||
self._policy.eval()
|
||||
validation_dicts = []
|
||||
for batch_idx, data in enumerate(self._val_dataloader):
|
||||
forward_dict = self._forward_pass(data) # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
|
||||
validation_dicts.append(forward_dict)
|
||||
if batch_idx > 50:  # limit the number of validation batches  # TODO: relate this cap to the batch size
|
||||
break
|
||||
|
||||
validation_summary = compute_dict_mean(validation_dicts)
|
||||
epoch_val_loss = validation_summary["loss"]
|
||||
if epoch_val_loss < min_val_loss:
|
||||
min_val_loss = epoch_val_loss
|
||||
best_ckpt_info = (
|
||||
step,
|
||||
min_val_loss,
|
||||
deepcopy(self._policy.serialize()),
|
||||
)
|
||||
|
||||
# log validation results to wandb
|
||||
# for k in list(validation_summary.keys()):
|
||||
# validation_summary[f"val_{k}"] = validation_summary.pop(k)
|
||||
|
||||
# wandb.log(validation_summary, step=step)
|
||||
logging.info(f"Val loss: {epoch_val_loss:.5f}")
|
||||
summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
|
||||
logging.info(summary_string)
|
||||
|
||||
# Evaluate the model
|
||||
# if (step > 0) and (step % self._eval_every == 0):
|
||||
# ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
|
||||
# ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
|
||||
# torch.save(self._policy.serialize(), ckpt_path)
|
||||
# success, _ = self._eval.evaluate(ckpt_name)
|
||||
# wandb.log({"success": success}, step=step)
|
||||
|
||||
# Training step
|
||||
self._policy.train()
|
||||
self._optimizer.zero_grad()
|
||||
data = next(train_dataloader)
|
||||
forward_dict = self._forward_pass(data)
|
||||
loss = forward_dict["loss"]
|
||||
loss.backward()
|
||||
self._optimizer.step()
|
||||
# wandb.log(forward_dict, step=step)
|
||||
|
||||
# Save a checkpoint
|
||||
if step % self._save_every == 0:
|
||||
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
|
||||
torch.save(self._policy.serialize(), ckpt_path)
|
||||
|
||||
# Save the final model
|
||||
ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
|
||||
torch.save(self._policy.serialize(), ckpt_path)
|
||||
|
||||
best_step, min_val_loss, best_state_dict = best_ckpt_info
|
||||
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
|
||||
torch.save(best_state_dict, ckpt_path)
|
||||
logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")
|
||||
|
||||
return best_ckpt_info
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
trainer = RmActTrainer(config)
|
||||
trainer.train()
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Utilities for bounding box manipulation and GIoU.
|
||||
"""
|
||||
import torch
|
||||
from torchvision.ops.boxes import box_area
|
||||
|
||||
|
||||
def box_cxcywh_to_xyxy(x):
|
||||
x_c, y_c, w, h = x.unbind(-1)
|
||||
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||
return torch.stack(b, dim=-1)
|
||||
|
||||
|
||||
def box_xyxy_to_cxcywh(x):
|
||||
x0, y0, x1, y1 = x.unbind(-1)
|
||||
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||
(x1 - x0), (y1 - y0)]
|
||||
return torch.stack(b, dim=-1)
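# Illustrative sketch (not part of the original file): the two converters are inverses,
# so a (cx, cy, w, h) box converted to (x0, y0, x1, y1) and back is unchanged.
def _example_box_roundtrip():
    boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])   # one box in cxcywh format
    xyxy = box_cxcywh_to_xyxy(boxes)               # tensor([[0.4, 0.3, 0.6, 0.7]])
    assert torch.allclose(box_xyxy_to_cxcywh(xyxy), boxes)
    return xyxy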
|
||||
|
||||
|
||||
# modified from torchvision to also return the union
|
||||
def box_iou(boxes1, boxes2):
|
||||
area1 = box_area(boxes1)
|
||||
area2 = box_area(boxes2)
|
||||
|
||||
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||
|
||||
union = area1[:, None] + area2 - inter
|
||||
|
||||
iou = inter / union
|
||||
return iou, union
|
||||
|
||||
|
||||
def generalized_box_iou(boxes1, boxes2):
|
||||
"""
|
||||
Generalized IoU from https://giou.stanford.edu/
|
||||
|
||||
The boxes should be in [x0, y0, x1, y1] format
|
||||
|
||||
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||
and M = len(boxes2)
|
||||
"""
|
||||
# degenerate boxes gives inf / nan results
|
||||
# so do an early check
|
||||
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||
iou, union = box_iou(boxes1, boxes2)
|
||||
|
||||
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||
|
||||
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||
area = wh[:, :, 0] * wh[:, :, 1]
|
||||
|
||||
return iou - (area - union) / area
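# Illustrative sketch (not part of the original file): generalized_box_iou is pairwise,
# returning [N, M]; identical boxes score 1.0 while disjoint boxes can drop below 0.
def _example_generalized_box_iou():
    a = torch.tensor([[0.0, 0.0, 1.0, 1.0]])
    b = torch.tensor([[0.0, 0.0, 1.0, 1.0], [2.0, 2.0, 3.0, 3.0]])
    giou = generalized_box_iou(a, b)   # shape [1, 2]; giou[0, 0] == 1.0, giou[0, 1] < 0
    return giou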
|
||||
|
||||
|
||||
def masks_to_boxes(masks):
|
||||
"""Compute the bounding boxes around the provided masks
|
||||
|
||||
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||
|
||||
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||
"""
|
||||
if masks.numel() == 0:
|
||||
return torch.zeros((0, 4), device=masks.device)
|
||||
|
||||
h, w = masks.shape[-2:]
|
||||
|
||||
y = torch.arange(0, h, dtype=torch.float)
|
||||
x = torch.arange(0, w, dtype=torch.float)
|
||||
y, x = torch.meshgrid(y, x)
|
||||
|
||||
x_mask = (masks * x.unsqueeze(0))
|
||||
x_max = x_mask.flatten(1).max(-1)[0]
|
||||
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
y_mask = (masks * y.unsqueeze(0))
|
||||
y_max = y_mask.flatten(1).max(-1)[0]
|
||||
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||
|
||||
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||
@@ -0,0 +1,468 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import pickle
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||
import torchvision
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
from torchvision.ops import _new_empty_tensor
|
||||
from torchvision.ops.misc import _output_size
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
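# Illustrative sketch (not part of the original file): SmoothedValue reports a windowed
# median/average alongside a global average over every update.
def _example_smoothed_value():
    meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
    for v in (1.0, 2.0, 3.0, 4.0):
        meter.update(v)
    return str(meter)   # "3.00 (2.50)": median over the window [2, 3, 4], global average over all four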
|
||||
|
||||
|
||||
def all_gather(data):
|
||||
"""
|
||||
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||
Args:
|
||||
data: any picklable object
|
||||
Returns:
|
||||
list[data]: list of data gathered from each rank
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size == 1:
|
||||
return [data]
|
||||
|
||||
# serialized to a Tensor
|
||||
buffer = pickle.dumps(data)
|
||||
storage = torch.ByteStorage.from_buffer(buffer)
|
||||
tensor = torch.ByteTensor(storage).to("cuda")
|
||||
|
||||
# obtain Tensor size of each rank
|
||||
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||
dist.all_gather(size_list, local_size)
|
||||
size_list = [int(size.item()) for size in size_list]
|
||||
max_size = max(size_list)
|
||||
|
||||
# receiving Tensor from all ranks
|
||||
# we pad the tensor because torch all_gather does not support
|
||||
# gathering tensors of different shapes
|
||||
tensor_list = []
|
||||
for _ in size_list:
|
||||
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||
if local_size != max_size:
|
||||
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||
tensor = torch.cat((tensor, padding), dim=0)
|
||||
dist.all_gather(tensor_list, tensor)
|
||||
|
||||
data_list = []
|
||||
for size, tensor in zip(size_list, tensor_list):
|
||||
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||
data_list.append(pickle.loads(buffer))
|
||||
|
||||
return data_list
|
||||
|
||||
|
||||
def reduce_dict(input_dict, average=True):
|
||||
"""
|
||||
Args:
|
||||
input_dict (dict): all the values will be reduced
|
||||
average (bool): whether to do average or sum
|
||||
Reduce the values in the dictionary from all processes so that all processes
|
||||
have the averaged results. Returns a dict with the same fields as
|
||||
input_dict, after reduction.
|
||||
"""
|
||||
world_size = get_world_size()
|
||||
if world_size < 2:
|
||||
return input_dict
|
||||
with torch.no_grad():
|
||||
names = []
|
||||
values = []
|
||||
# sort the keys so that they are consistent across processes
|
||||
for k in sorted(input_dict.keys()):
|
||||
names.append(k)
|
||||
values.append(input_dict[k])
|
||||
values = torch.stack(values, dim=0)
|
||||
dist.all_reduce(values)
|
||||
if average:
|
||||
values /= world_size
|
||||
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||
return reduced_dict
|
||||
|
||||
|
||||
class MetricLogger(object):
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
assert isinstance(v, (float, int))
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||
type(self).__name__, attr))
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(
|
||||
"{}: {}".format(name, str(meter))
|
||||
)
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ''
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}',
|
||||
'max mem: {memory:.0f}'
|
||||
])
|
||||
else:
|
||||
log_msg = self.delimiter.join([
|
||||
header,
|
||||
'[{0' + space_fmt + '}/{1}]',
|
||||
'eta: {eta}',
|
||||
'{meters}',
|
||||
'time: {time}',
|
||||
'data: {data}'
|
||||
])
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB))
|
||||
else:
|
||||
print(log_msg.format(
|
||||
i, len(iterable), eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time), data=str(data_time)))
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||
header, total_time_str, total_time / len(iterable)))
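# Illustrative sketch (not part of the original file): MetricLogger wraps a sized iterable
# of batches and prints smoothed meters plus an ETA every `print_freq` iterations;
# `data_loader` here is a placeholder for any iterable that supports len().
def _example_metric_logger(data_loader):
    logger = MetricLogger(delimiter="  ")
    for batch in logger.log_every(data_loader, print_freq=10, header="Train:"):
        loss = float(torch.rand(()))   # stand-in for a real training step
        logger.update(loss=loss)
    return logger.meters["loss"].global_avg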
|
||||
|
||||
|
||||
def get_sha():
|
||||
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
def _run(command):
|
||||
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||
sha = 'N/A'
|
||||
diff = "clean"
|
||||
branch = 'N/A'
|
||||
try:
|
||||
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||
diff = "has uncommited changes" if diff else "clean"
|
||||
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||
except Exception:
|
||||
pass
|
||||
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
return message
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
batch = list(zip(*batch))
|
||||
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||
return tuple(batch)
|
||||
|
||||
|
||||
def _max_by_axis(the_list):
|
||||
# type: (List[List[int]]) -> List[int]
|
||||
maxes = the_list[0]
|
||||
for sublist in the_list[1:]:
|
||||
for index, item in enumerate(sublist):
|
||||
maxes[index] = max(maxes[index], item)
|
||||
return maxes
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def to(self, device):
|
||||
# type: (Device) -> NestedTensor # noqa
|
||||
cast_tensor = self.tensors.to(device)
|
||||
mask = self.mask
|
||||
if mask is not None:
|
||||
assert mask is not None
|
||||
cast_mask = mask.to(device)
|
||||
else:
|
||||
cast_mask = None
|
||||
return NestedTensor(cast_tensor, cast_mask)
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||
# TODO make this more general
|
||||
if tensor_list[0].ndim == 3:
|
||||
if torchvision._is_tracing():
|
||||
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||
|
||||
# TODO make it support different-sized images
|
||||
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||
batch_shape = [len(tensor_list)] + max_size
|
||||
b, c, h, w = batch_shape
|
||||
dtype = tensor_list[0].dtype
|
||||
device = tensor_list[0].device
|
||||
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
m[: img.shape[1], :img.shape[2]] = False
|
||||
else:
|
||||
raise ValueError('not supported')
|
||||
return NestedTensor(tensor, mask)
|
||||
|
||||
|
||||
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||
@torch.jit.unused
|
||||
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||
max_size = []
|
||||
for i in range(tensor_list[0].dim()):
|
||||
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||
max_size.append(max_size_i)
|
||||
max_size = tuple(max_size)
|
||||
|
||||
# work around for
|
||||
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||
# m[: img.shape[1], :img.shape[2]] = False
|
||||
# which is not yet supported in onnx
|
||||
padded_imgs = []
|
||||
padded_masks = []
|
||||
for img in tensor_list:
|
||||
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||
padded_imgs.append(padded_img)
|
||||
|
||||
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||
padded_masks.append(padded_mask.to(torch.bool))
|
||||
|
||||
tensor = torch.stack(padded_imgs)
|
||||
mask = torch.stack(padded_masks)
|
||||
|
||||
return NestedTensor(tensor, mask=mask)
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop('force', False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||
elif 'SLURM_PROCID' in os.environ:
|
||||
args.rank = int(os.environ['SLURM_PROCID'])
|
||||
args.gpu = args.rank % torch.cuda.device_count()
|
||||
else:
|
||||
print('Not using distributed mode')
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = 'nccl'
|
||||
print('| distributed init (rank {}): {}'.format(
|
||||
args.rank, args.dist_url), flush=True)
|
||||
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||
world_size=args.world_size, rank=args.rank)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
if target.numel() == 0:
|
||||
return [torch.zeros([], device=output.device)]
|
||||
maxk = max(topk)
|
||||
batch_size = target.size(0)
|
||||
|
||||
_, pred = output.topk(maxk, 1, True, True)
|
||||
pred = pred.t()
|
||||
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
||||
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||
"""
|
||||
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||
This will eventually be supported natively by PyTorch, and this
|
||||
class can go away.
|
||||
"""
|
||||
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||
if input.numel() > 0:
|
||||
return torch.nn.functional.interpolate(
|
||||
input, size, scale_factor, mode, align_corners
|
||||
)
|
||||
|
||||
output_shape = _output_size(2, input, size, scale_factor)
|
||||
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||
return _new_empty_tensor(input, output_shape)
|
||||
else:
|
||||
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Plotting utilities to visualize training logs.
|
||||
"""
|
||||
import torch
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pathlib import Path, PurePath
|
||||
|
||||
|
||||
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||
'''
|
||||
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||
|
||||
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||
|
||||
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||
- solid lines are training results, dashed lines are test results.
|
||||
|
||||
'''
|
||||
func_name = "plot_utils.py::plot_logs"
|
||||
|
||||
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||
# convert single Path to list to avoid 'not iterable' error
|
||||
|
||||
if not isinstance(logs, list):
|
||||
if isinstance(logs, PurePath):
|
||||
logs = [logs]
|
||||
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||
else:
|
||||
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||
|
||||
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||
for i, dir in enumerate(logs):
|
||||
if not isinstance(dir, PurePath):
|
||||
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||
if not dir.exists():
|
||||
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||
# verify log_name exists
|
||||
fn = Path(dir / log_name)
|
||||
if not fn.exists():
|
||||
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||
print(f"--> full path of missing log file: {fn}")
|
||||
return
|
||||
|
||||
# load log file(s) and plot
|
||||
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||
|
||||
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||
|
||||
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||
for j, field in enumerate(fields):
|
||||
if field == 'mAP':
|
||||
coco_eval = pd.DataFrame(
|
||||
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||
).ewm(com=ewm_col).mean()
|
||||
axs[j].plot(coco_eval, c=color)
|
||||
else:
|
||||
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||
y=[f'train_{field}', f'test_{field}'],
|
||||
ax=axs[j],
|
||||
color=[color] * 2,
|
||||
style=['-', '--']
|
||||
)
|
||||
for ax, field in zip(axs, fields):
|
||||
ax.legend([Path(p).name for p in logs])
|
||||
ax.set_title(field)
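# Illustrative sketch (not part of the original file): plot_logs expects one directory per
# run, each containing a DETR-style log.txt with train_/test_ columns for the chosen fields;
# the paths and field names below are placeholders.
def _example_plot_logs():
    run_dirs = [Path("outputs/run1"), Path("outputs/run2")]
    plot_logs(run_dirs, fields=("loss", "class_error"), ewm_col=5)
    plt.show()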
|
||||
|
||||
|
||||
def plot_precision_recall(files, naming_scheme='iter'):
|
||||
if naming_scheme == 'exp_id':
|
||||
# name becomes exp_id
|
||||
names = [f.parts[-3] for f in files]
|
||||
elif naming_scheme == 'iter':
|
||||
names = [f.stem for f in files]
|
||||
else:
|
||||
raise ValueError(f'not supported {naming_scheme}')
|
||||
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||
data = torch.load(f)
|
||||
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||
precision = data['precision']
|
||||
recall = data['params'].recThrs
|
||||
scores = data['scores']
|
||||
# take precision for all classes, all areas and 100 detections
|
||||
precision = precision[0, :, :, 0, -1].mean(1)
|
||||
scores = scores[0, :, :, 0, -1].mean(1)
|
||||
prec = precision.mean()
|
||||
rec = data['recall'][0, :, 0, -1].mean()
|
||||
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||
f'score={scores.mean():0.3f}, ' +
|
||||
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||
)
|
||||
axs[0].plot(recall, precision, c=color)
|
||||
axs[1].plot(recall, scores, c=color)
|
||||
|
||||
axs[0].set_title('Precision / Recall')
|
||||
axs[0].legend(names)
|
||||
axs[1].set_title('Scores / Recall')
|
||||
axs[1].legend(names)
|
||||
return fig, axs
|
||||
@@ -0,0 +1,499 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
import h5py
|
||||
import pickle
|
||||
import fnmatch
|
||||
import cv2
|
||||
from time import time
|
||||
from torch.utils.data import TensorDataset, DataLoader
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
|
||||
|
||||
|
||||
def flatten_list(l):
|
||||
return [item for sublist in l for item in sublist]
|
||||
|
||||
|
||||
class EpisodicDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
episode_ids,
|
||||
episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
):
|
||||
super().__init__()
|
||||
self.episode_ids = episode_ids
|
||||
self.dataset_path_list = dataset_path_list
|
||||
self.camera_names = camera_names
|
||||
self.norm_stats = norm_stats
|
||||
self.episode_len = episode_len
|
||||
self.chunk_size = chunk_size
|
||||
self.cumulative_len = np.cumsum(self.episode_len)
|
||||
self.max_episode_len = max(episode_len)
|
||||
self.policy_class = policy_class
|
||||
if self.policy_class == "Diffusion":
|
||||
self.augment_images = True
|
||||
else:
|
||||
self.augment_images = False
|
||||
self.transformations = None
|
||||
self.__getitem__(0) # initialize self.is_sim and self.transformations
|
||||
self.is_sim = False
|
||||
|
||||
# def __len__(self):
|
||||
# return sum(self.episode_len)
|
||||
|
||||
def _locate_transition(self, index):
|
||||
assert index < self.cumulative_len[-1]
|
||||
episode_index = np.argmax(
|
||||
self.cumulative_len > index
|
||||
) # argmax returns first True index
|
||||
start_ts = index - (
|
||||
self.cumulative_len[episode_index] - self.episode_len[episode_index]
|
||||
)
|
||||
episode_id = self.episode_ids[episode_index]
|
||||
return episode_id, start_ts
|
||||
|
||||
def __getitem__(self, index):
|
||||
episode_id, start_ts = self._locate_transition(index)
|
||||
dataset_path = self.dataset_path_list[episode_id]
|
||||
try:
|
||||
# print(dataset_path)
|
||||
with h5py.File(dataset_path, "r") as root:
|
||||
try: # some legacy data does not have this attribute
|
||||
is_sim = root.attrs["sim"]
|
||||
except KeyError:
|
||||
is_sim = False
|
||||
compressed = root.attrs.get("compress", False)
|
||||
if "/base_action" in root:
|
||||
base_action = root["/base_action"][()]
|
||||
base_action = preprocess_base_action(base_action)
|
||||
action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||
else:
|
||||
# TODO
|
||||
action = root["/action"][()]
|
||||
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||
original_action_shape = action.shape
|
||||
episode_len = original_action_shape[0]
|
||||
# get observation at start_ts only
|
||||
qpos = root["/observations/qpos"][start_ts]
|
||||
qvel = root["/observations/qvel"][start_ts]
|
||||
image_dict = dict()
|
||||
for cam_name in self.camera_names:
|
||||
image_dict[cam_name] = root[f"/observations/images/{cam_name}"][
|
||||
start_ts
|
||||
]
|
||||
|
||||
if compressed:
|
||||
for cam_name in image_dict.keys():
|
||||
decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
|
||||
image_dict[cam_name] = np.array(decompressed_image)
|
||||
|
||||
# get all actions after and including start_ts
|
||||
if is_sim:
|
||||
action = action[start_ts:]
|
||||
action_len = episode_len - start_ts
|
||||
else:
|
||||
action = action[
|
||||
max(0, start_ts - 1) :
|
||||
] # hack, to make timesteps more aligned
|
||||
action_len = episode_len - max(
|
||||
0, start_ts - 1
|
||||
) # hack, to make timesteps more aligned
|
||||
|
||||
# self.is_sim = is_sim
|
||||
padded_action = np.zeros(
|
||||
(self.max_episode_len, original_action_shape[1]), dtype=np.float32
|
||||
)
|
||||
padded_action[:action_len] = action
|
||||
is_pad = np.zeros(self.max_episode_len)
|
||||
is_pad[action_len:] = 1
|
||||
|
||||
padded_action = padded_action[: self.chunk_size]
|
||||
is_pad = is_pad[: self.chunk_size]
|
||||
|
||||
# new axis for different cameras
|
||||
all_cam_images = []
|
||||
for cam_name in self.camera_names:
|
||||
all_cam_images.append(image_dict[cam_name])
|
||||
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||
|
||||
# construct observations
|
||||
image_data = torch.from_numpy(all_cam_images)
|
||||
qpos_data = torch.from_numpy(qpos).float()
|
||||
action_data = torch.from_numpy(padded_action).float()
|
||||
is_pad = torch.from_numpy(is_pad).bool()
|
||||
|
||||
# channel last
|
||||
image_data = torch.einsum("k h w c -> k c h w", image_data)
|
||||
|
||||
# augmentation
|
||||
if self.transformations is None:
|
||||
print("Initializing transformations")
|
||||
original_size = image_data.shape[2:]
|
||||
ratio = 0.95
|
||||
self.transformations = [
|
||||
transforms.RandomCrop(
|
||||
size=[
|
||||
int(original_size[0] * ratio),
|
||||
int(original_size[1] * ratio),
|
||||
]
|
||||
),
|
||||
transforms.Resize(original_size, antialias=True),
|
||||
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.3, contrast=0.4, saturation=0.5
|
||||
), # , hue=0.08)
|
||||
]
|
||||
|
||||
if self.augment_images:
|
||||
for transform in self.transformations:
|
||||
image_data = transform(image_data)
|
||||
|
||||
# normalize image and change dtype to float
|
||||
image_data = image_data / 255.0
|
||||
|
||||
if self.policy_class == "Diffusion":
|
||||
# normalize to [-1, 1]
|
||||
action_data = (
|
||||
(action_data - self.norm_stats["action_min"])
|
||||
/ (self.norm_stats["action_max"] - self.norm_stats["action_min"])
|
||||
) * 2 - 1
|
||||
else:
|
||||
# normalize to mean 0 std 1
|
||||
action_data = (
|
||||
action_data - self.norm_stats["action_mean"]
|
||||
) / self.norm_stats["action_std"]
|
||||
|
||||
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
|
||||
"qpos_std"
|
||||
]
|
||||
|
||||
except Exception as e:
print(f"Error loading {dataset_path} in __getitem__: {e}")
|
||||
quit()
|
||||
|
||||
# print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
|
||||
return image_data, qpos_data, action_data, is_pad
|
||||
|
||||
|
||||
def get_norm_stats(dataset_path_list):
|
||||
all_qpos_data = []
|
||||
all_action_data = []
|
||||
all_episode_len = []
|
||||
|
||||
for dataset_path in dataset_path_list:
|
||||
try:
|
||||
with h5py.File(dataset_path, "r") as root:
|
||||
qpos = root["/observations/qpos"][()]
|
||||
qvel = root["/observations/qvel"][()]
|
||||
if "/base_action" in root:
|
||||
base_action = root["/base_action"][()]
|
||||
base_action = preprocess_base_action(base_action)
action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||
else:
|
||||
# TODO
|
||||
action = root["/action"][()]
|
||||
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||
except Exception as e:
|
||||
print(f"Error loading {dataset_path} in get_norm_stats")
|
||||
print(e)
|
||||
quit()
|
||||
all_qpos_data.append(torch.from_numpy(qpos))
|
||||
all_action_data.append(torch.from_numpy(action))
|
||||
all_episode_len.append(len(qpos))
|
||||
all_qpos_data = torch.cat(all_qpos_data, dim=0)
|
||||
all_action_data = torch.cat(all_action_data, dim=0)
|
||||
|
||||
# normalize action data
|
||||
action_mean = all_action_data.mean(dim=[0]).float()
|
||||
action_std = all_action_data.std(dim=[0]).float()
|
||||
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
|
||||
|
||||
# normalize qpos data
|
||||
qpos_mean = all_qpos_data.mean(dim=[0]).float()
|
||||
qpos_std = all_qpos_data.std(dim=[0]).float()
|
||||
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
|
||||
|
||||
action_min = all_action_data.min(dim=0).values.float()
|
||||
action_max = all_action_data.max(dim=0).values.float()
|
||||
|
||||
eps = 0.0001
|
||||
stats = {
|
||||
"action_mean": action_mean.numpy(),
|
||||
"action_std": action_std.numpy(),
|
||||
"action_min": action_min.numpy() - eps,
|
||||
"action_max": action_max.numpy() + eps,
|
||||
"qpos_mean": qpos_mean.numpy(),
|
||||
"qpos_std": qpos_std.numpy(),
|
||||
"example_qpos": qpos,
|
||||
}
|
||||
|
||||
return stats, all_episode_len
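# Illustrative sketch (not part of the original file): the stats returned above are applied
# exactly as EpisodicDataset normalizes each sample.
def _example_apply_norm_stats(stats, qpos, action):
    qpos_norm = (qpos - stats["qpos_mean"]) / stats["qpos_std"]
    action_norm = (action - stats["action_mean"]) / stats["action_std"]
    return qpos_norm, action_norm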
|
||||
|
||||
|
||||
def find_all_hdf5(dataset_dir, skip_mirrored_data):
|
||||
hdf5_files = []
|
||||
for root, dirs, files in os.walk(dataset_dir):
|
||||
for filename in fnmatch.filter(files, "*.hdf5"):
|
||||
if "features" in filename:
|
||||
continue
|
||||
if skip_mirrored_data and "mirror" in filename:
|
||||
continue
|
||||
hdf5_files.append(os.path.join(root, filename))
|
||||
print(f"Found {len(hdf5_files)} hdf5 files")
|
||||
return hdf5_files
|
||||
|
||||
|
||||
def BatchSampler(batch_size, episode_len_l, sample_weights):
|
||||
sample_probs = (
|
||||
np.array(sample_weights) / np.sum(sample_weights)
|
||||
if sample_weights is not None
|
||||
else None
|
||||
)
|
||||
# print("BatchSampler", sample_weights)
|
||||
sum_dataset_len_l = np.cumsum(
|
||||
[0] + [np.sum(episode_len) for episode_len in episode_len_l]
|
||||
)
|
||||
while True:
|
||||
batch = []
|
||||
for _ in range(batch_size):
|
||||
episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
|
||||
step_idx = np.random.randint(
|
||||
sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1]
|
||||
)
|
||||
batch.append(step_idx)
|
||||
yield batch
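# Illustrative sketch (not part of the original file): BatchSampler is an endless generator
# of flat step indices, which is why it is handed to DataLoader(batch_sampler=...) below.
def _example_batch_sampler():
    episode_len_l = [[400, 380], [500]]                 # per-dataset lists of episode lengths
    sampler = BatchSampler(8, episode_len_l, sample_weights=[1, 1])
    first_batch = next(sampler)                         # list of 8 indices into the concatenated steps
    return first_batch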
|
||||
|
||||
|
||||
def load_data(
|
||||
dataset_dir_l,
|
||||
name_filter,
|
||||
camera_names,
|
||||
batch_size_train,
|
||||
batch_size_val,
|
||||
chunk_size,
|
||||
skip_mirrored_data=False,
|
||||
load_pretrain=False,
|
||||
policy_class=None,
|
||||
stats_dir_l=None,
|
||||
sample_weights=None,
|
||||
train_ratio=0.99,
|
||||
):
|
||||
if type(dataset_dir_l) == str:
|
||||
dataset_dir_l = [dataset_dir_l]
|
||||
dataset_path_list_list = [
|
||||
find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l
|
||||
]
|
||||
num_episodes_0 = len(dataset_path_list_list[0])
|
||||
dataset_path_list = flatten_list(dataset_path_list_list)
|
||||
|
||||
dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
|
||||
num_episodes_l = [
|
||||
len(dataset_path_list) for dataset_path_list in dataset_path_list_list
|
||||
]
|
||||
num_episodes_cumsum = np.cumsum(num_episodes_l)
|
||||
|
||||
# obtain train test split on dataset_dir_l[0]
|
||||
shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
|
||||
train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
|
||||
val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :]
|
||||
train_episode_ids_l = [train_episode_ids_0] + [
|
||||
np.arange(num_episodes) + num_episodes_cumsum[idx]
|
||||
for idx, num_episodes in enumerate(num_episodes_l[1:])
|
||||
]
|
||||
val_episode_ids_l = [val_episode_ids_0]
|
||||
train_episode_ids = np.concatenate(train_episode_ids_l)
|
||||
val_episode_ids = np.concatenate(val_episode_ids_l)
|
||||
print(
|
||||
f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
|
||||
)
|
||||
|
||||
# obtain normalization stats for qpos and action
|
||||
# if load_pretrain:
|
||||
# with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
|
||||
# norm_stats = pickle.load(f)
|
||||
# print('Loaded pretrain dataset stats')
|
||||
_, all_episode_len = get_norm_stats(dataset_path_list)
|
||||
train_episode_len_l = [
|
||||
[all_episode_len[i] for i in train_episode_ids]
|
||||
for train_episode_ids in train_episode_ids_l
|
||||
]
|
||||
val_episode_len_l = [
|
||||
[all_episode_len[i] for i in val_episode_ids]
|
||||
for val_episode_ids in val_episode_ids_l
|
||||
]
|
||||
|
||||
train_episode_len = flatten_list(train_episode_len_l)
|
||||
val_episode_len = flatten_list(val_episode_len_l)
|
||||
if stats_dir_l is None:
|
||||
stats_dir_l = dataset_dir_l
|
||||
elif type(stats_dir_l) == str:
|
||||
stats_dir_l = [stats_dir_l]
|
||||
norm_stats, _ = get_norm_stats(
|
||||
flatten_list(
|
||||
[find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]
|
||||
)
|
||||
)
|
||||
print(f"Norm stats from: {stats_dir_l}")
|
||||
|
||||
batch_sampler_train = BatchSampler(
|
||||
batch_size_train, train_episode_len_l, sample_weights
|
||||
)
|
||||
batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)
|
||||
|
||||
# print(f'train_episode_len: {train_episode_len}, val_episode_len: {val_episode_len}, train_episode_ids: {train_episode_ids}, val_episode_ids: {val_episode_ids}')
|
||||
|
||||
# construct dataset and dataloader
|
||||
train_dataset = EpisodicDataset(
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
train_episode_ids,
|
||||
train_episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
)
|
||||
val_dataset = EpisodicDataset(
|
||||
dataset_path_list,
|
||||
camera_names,
|
||||
norm_stats,
|
||||
val_episode_ids,
|
||||
val_episode_len,
|
||||
chunk_size,
|
||||
policy_class,
|
||||
)
|
||||
train_num_workers = (
|
||||
(8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
|
||||
)
|
||||
val_num_workers = 8 if train_dataset.augment_images else 2
|
||||
print(
|
||||
f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
|
||||
)
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler_train,
|
||||
pin_memory=True,
|
||||
num_workers=train_num_workers,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_sampler=batch_sampler_val,
|
||||
pin_memory=True,
|
||||
num_workers=val_num_workers,
|
||||
prefetch_factor=2,
|
||||
)
|
||||
|
||||
return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
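# Illustrative sketch (not part of the original file): a minimal load_data call; the dataset
# directory, camera names and batch sizes are placeholders.
def _example_load_data():
    train_loader, val_loader, norm_stats, is_sim = load_data(
        dataset_dir_l="/path/to/hdf5_episodes",   # hypothetical dataset directory
        name_filter=lambda name: True,            # keep every episode file
        camera_names=["cam_high", "cam_low"],
        batch_size_train=8,
        batch_size_val=8,
        chunk_size=100,
        policy_class="ACT",
    )
    return train_loader, val_loader, norm_stats, is_sim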
|
||||
|
||||
|
||||
def calibrate_linear_vel(base_action, c=None):
|
||||
if c is None:
|
||||
c = 0.0 # 0.19
|
||||
v = base_action[..., 0]
|
||||
w = base_action[..., 1]
|
||||
base_action = base_action.copy()
|
||||
base_action[..., 0] = v - c * w
|
||||
return base_action
|
||||
|
||||
|
||||
def smooth_base_action(base_action):
|
||||
return np.stack(
|
||||
[
|
||||
np.convolve(base_action[:, i], np.ones(5) / 5, mode="same")
|
||||
for i in range(base_action.shape[1])
|
||||
],
|
||||
axis=-1,
|
||||
).astype(np.float32)
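# Illustrative sketch (not part of the original file): smooth_base_action applies a 5-step
# moving average to each base-action channel independently, preserving the [T, 2] shape.
def _example_smooth_base_action():
    base_action = np.stack(
        [np.arange(10, dtype=np.float32), np.zeros(10, dtype=np.float32)], axis=1
    )
    smoothed = smooth_base_action(base_action)
    return smoothed.shape   # (10, 2)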
|
||||
|
||||
|
||||
def preprocess_base_action(base_action):
|
||||
# base_action = calibrate_linear_vel(base_action)
|
||||
base_action = smooth_base_action(base_action)
|
||||
|
||||
return base_action
|
||||
|
||||
|
||||
def postprocess_base_action(base_action):
|
||||
linear_vel, angular_vel = base_action
|
||||
linear_vel *= 1.0
|
||||
angular_vel *= 1.0
|
||||
# angular_vel = 0
|
||||
# if np.abs(linear_vel) < 0.05:
|
||||
# linear_vel = 0
|
||||
return np.array([linear_vel, angular_vel])
|
||||
|
||||
|
||||
### env utils
|
||||
|
||||
|
||||
def sample_box_pose():
|
||||
x_range = [0.0, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
cube_quat = np.array([1, 0, 0, 0])
|
||||
return np.concatenate([cube_position, cube_quat])
|
||||
|
||||
|
||||
def sample_insertion_pose():
|
||||
# Peg
|
||||
x_range = [0.1, 0.2]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
peg_quat = np.array([1, 0, 0, 0])
|
||||
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||
|
||||
# Socket
|
||||
x_range = [-0.2, -0.1]
|
||||
y_range = [0.4, 0.6]
|
||||
z_range = [0.05, 0.05]
|
||||
|
||||
ranges = np.vstack([x_range, y_range, z_range])
|
||||
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||
|
||||
socket_quat = np.array([1, 0, 0, 0])
|
||||
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||
|
||||
return peg_pose, socket_pose
|
||||
|
||||
|
||||
### helper functions
|
||||
|
||||
|
||||
def compute_dict_mean(epoch_dicts):
|
||||
result = {k: None for k in epoch_dicts[0]}
|
||||
num_items = len(epoch_dicts)
|
||||
for k in result:
|
||||
value_sum = 0
|
||||
for epoch_dict in epoch_dicts:
|
||||
value_sum += epoch_dict[k]
|
||||
result[k] = value_sum / num_items
|
||||
return result
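# Illustrative sketch (not part of the original file): compute_dict_mean averages each key
# across a list of per-batch dicts, as done for the validation summaries in the trainer.
def _example_compute_dict_mean():
    dicts = [{"loss": 1.0, "kl": 0.1}, {"loss": 3.0, "kl": 0.3}]
    return compute_dict_mean(dicts)   # {"loss": 2.0, "kl": 0.2}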
|
||||
|
||||
|
||||
def detach_dict(d):
|
||||
new_d = dict()
|
||||
for k, v in d.items():
|
||||
new_d[k] = v.detach()
|
||||
return new_d
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
torch.manual_seed(seed)
|
||||
np.random.seed(seed)
|
||||
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py (new file, 163 lines)
@@ -0,0 +1,163 @@
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
import yaml
|
||||
import time
|
||||
import multiprocessing
|
||||
import numpy as np
|
||||
import collections
|
||||
import logging
|
||||
import dm_env
|
||||
import tracemalloc
|
||||
|
||||
|
||||
class DeviceAloha:
|
||||
def __init__(self, aloha_config):
|
||||
"""
|
||||
Initialize the ALOHA device (arms and cameras).
|
||||
|
||||
Args:
|
||||
aloha_config (dict): device configuration.
|
||||
"""
|
||||
config_left_arm = aloha_config["rm_left_arm"]
|
||||
config_right_arm = aloha_config["rm_right_arm"]
|
||||
config_head_camera = aloha_config["head_camera"]
|
||||
config_bottom_camera = aloha_config["bottom_camera"]
|
||||
config_left_camera = aloha_config["left_camera"]
|
||||
config_right_camera = aloha_config["right_camera"]
|
||||
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||
self.arm_left = RmArm(config_left_arm)
|
||||
self.arm_right = RmArm(config_right_arm)
|
||||
self.camera_top = RealSenseCamera(config_head_camera, False)
self.camera_bottom = RealSenseCamera(config_bottom_camera, False)
self.camera_left = RealSenseCamera(config_left_camera, False)
self.camera_right = RealSenseCamera(config_right_camera, False)
|
||||
self.camera_left.start_camera()
|
||||
self.camera_right.start_camera()
|
||||
self.camera_bottom.start_camera()
|
||||
self.camera_top.start_camera()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close all cameras.
|
||||
"""
|
||||
self.camera_left.close()
|
||||
self.camera_right.close()
|
||||
self.camera_bottom.close()
|
||||
self.camera_top.close()
|
||||
|
||||
def get_qps(self):
|
||||
"""
|
||||
Get joint angles.
|
||||
|
||||
Returns:
|
||||
np.array: joint angles of both arms
|
||||
"""
|
||||
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||
|
||||
def get_qvel(self):
|
||||
"""
|
||||
Get joint velocities.
|
||||
|
||||
Returns:
|
||||
np.array: joint velocities of both arms
|
||||
"""
|
||||
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||
|
||||
def get_effort(self):
|
||||
"""
|
||||
Get joint efforts.
|
||||
|
||||
Returns:
|
||||
np.array: joint efforts of both arms
|
||||
"""
|
||||
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||
|
||||
def get_images(self):
|
||||
"""
|
||||
Get images from all cameras.
|
||||
|
||||
Returns:
|
||||
dict: mapping from camera name to image
|
||||
"""
|
||||
top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||
bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
|
||||
left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
|
||||
right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
|
||||
return {
|
||||
"cam_high": top_image,
|
||||
"cam_low": bottom_image,
|
||||
"cam_left": left_image,
|
||||
"cam_right": right_image,
|
||||
}
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qps()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
# self.get_images()
|
||||
return obs
|
||||
|
||||
def reset(self):
|
||||
logging.info("Resetting the environment")
|
||||
_ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
|
||||
_ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
|
||||
self.arm_left.set_gripper_position(0)
|
||||
self.arm_right.set_gripper_position(0)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
def step(self, target_angle):
|
||||
self.arm_left.set_joint_canfd_position(target_angle[0:6])
|
||||
self.arm_right.set_joint_canfd_position(target_angle[7:13])
|
||||
self.arm_left.set_gripper_position(target_angle[6])
|
||||
self.arm_right.set_gripper_position(target_angle[13])
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID,
|
||||
reward=0,
|
||||
discount=None,
|
||||
observation=self.get_observation(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
aloha_config = config["robot_env"]
|
||||
device = DeviceAloha(aloha_config)
|
||||
device.reset()
|
||||
image_list = []
|
||||
target_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||
while True:
|
||||
tracemalloc.start()  # start memory tracing
|
||||
|
||||
target_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(target_angle)])
|
||||
time_step = time.time()
|
||||
timestep = device.step(target_angle)
|
||||
logging.info(f"Time: {time.time() - time_step}")
|
||||
image_list.append(timestep.observation["images"])
|
||||
snapshot = tracemalloc.take_snapshot()
|
||||
top_stats = snapshot.statistics('lineno')
|
||||
# del timestep
|
||||
print("[ Top 10 ]")
|
||||
for stat in top_stats[:10]:
|
||||
print(stat)
|
||||
# logging.info(f"Images: {obs}")
|
||||
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py (new file, 32 lines)
@@ -0,0 +1,32 @@
|
||||
import os
|
||||
# import time
|
||||
import yaml
|
||||
import torch
|
||||
import pickle
|
||||
import dm_env
|
||||
import logging
|
||||
import collections
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||
from shadow_act.utils.utils import (
|
||||
load_data,
|
||||
sample_box_pose,
|
||||
sample_insertion_pose,
|
||||
compute_dict_mean,
|
||||
set_seed,
|
||||
detach_dict,
|
||||
)
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
print('daasdas')
|
||||
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py (new file, 147 lines)
@@ -0,0 +1,147 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from constants import DT
|
||||
|
||||
import IPython
|
||||
e = IPython.embed
|
||||
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
STATE_NAMES = JOINT_NAMES + ["gripper"]
|
||||
|
||||
def load_hdf5(dataset_dir, dataset_name):
|
||||
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
|
||||
if not os.path.isfile(dataset_path):
|
||||
print(f'Dataset does not exist at \n{dataset_path}\n')
|
||||
exit()
|
||||
|
||||
with h5py.File(dataset_path, 'r') as root:
|
||||
is_sim = root.attrs['sim']
|
||||
qpos = root['/observations/qpos'][()]
|
||||
qvel = root['/observations/qvel'][()]
|
||||
action = root['/action'][()]
|
||||
image_dict = dict()
|
||||
for cam_name in root[f'/observations/images/'].keys():
|
||||
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
||||
|
||||
return qpos, qvel, action, image_dict
|
||||
|
||||
def main(args):
|
||||
dataset_dir = args['dataset_dir']
|
||||
episode_idx = args['episode_idx']
|
||||
dataset_name = f'episode_{episode_idx}'
|
||||
|
||||
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
|
||||
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
||||
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
||||
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
||||
|
||||
|
||||
def save_videos(video, dt, video_path=None):
|
||||
if isinstance(video, list):
|
||||
cam_names = list(video[0].keys())
|
||||
h, w, _ = video[0][cam_names[0]].shape
|
||||
w = w * len(cam_names)
|
||||
fps = int(1/dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for ts, image_dict in enumerate(video):
|
||||
images = []
|
||||
for cam_name in cam_names:
|
||||
image = image_dict[cam_name]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
images.append(image)
|
||||
images = np.concatenate(images, axis=1)
|
||||
out.write(images)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
elif isinstance(video, dict):
|
||||
cam_names = list(video.keys())
|
||||
all_cam_videos = []
|
||||
for cam_name in cam_names:
|
||||
all_cam_videos.append(video[cam_name])
|
||||
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
||||
|
||||
n_frames, h, w, _ = all_cam_videos.shape
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for t in range(n_frames):
|
||||
image = all_cam_videos[t]
|
||||
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||
out.write(image)
|
||||
out.release()
|
||||
print(f'Saved video to: {video_path}')
|
||||
|
||||
|
||||
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
|
||||
if label_overwrite:
|
||||
label1, label2 = label_overwrite
|
||||
else:
|
||||
label1, label2 = 'State', 'Command'
|
||||
|
||||
qpos = np.array(qpos_list) # ts, dim
|
||||
command = np.array(command_list)
|
||||
num_ts, num_dim = qpos.shape
|
||||
h, w = 2, num_dim
|
||||
num_figs = num_dim
|
||||
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
|
||||
|
||||
# plot joint state
|
||||
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=label1)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
|
||||
# plot arm command
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(command[:, dim_idx], label=label2)
|
||||
ax.legend()
|
||||
|
||||
if ylim:
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved qpos plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_timestamp(t_list, dataset_path):
|
||||
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
|
||||
h, w = 4, 10
|
||||
fig, axs = plt.subplots(2, 1, figsize=(w, h*2))
|
||||
# process t_list
|
||||
t_float = []
|
||||
for secs, nsecs in t_list:
|
||||
t_float.append(secs + nsecs * 1e-9)
|
||||
t_float = np.array(t_float)
|
||||
|
||||
ax = axs[0]
|
||||
ax.plot(np.arange(len(t_float)), t_float)
|
||||
ax.set_title(f'Camera frame timestamps')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
ax = axs[1]
|
||||
ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:])
|
||||
ax.set_title(f'dt')
|
||||
ax.set_xlabel('timestep')
|
||||
ax.set_ylabel('time (sec)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
print(f'Saved timestamp plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
|
||||
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
|
||||
main(vars(parser.parse_args()))
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.gitignore (new file, vendored, 10 lines)
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
build/
|
||||
devel/
|
||||
dist/
|
||||
data/
|
||||
.catkin_workspace
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
.vscode/
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore (new file, generated, vendored, 3 lines)
@@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name (new file, generated, 1 line)
@@ -0,0 +1 @@
|
||||
aloha_data_synchronizer.py
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml (new file, generated, 17 lines)
@@ -0,0 +1,17 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="4">
|
||||
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
|
||||
<item index="1" class="java.lang.String" itemvalue="thop" />
|
||||
<item index="2" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="3" class="java.lang.String" itemvalue="torchvision" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml (new file, generated, 6 lines)
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml (new file, generated, 7 lines)
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.11 (随箱软件)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml (new file, generated, 8 lines)
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
    </modules>
  </component>
</project>
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
@@ -0,0 +1,38 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
dataset_name: 'episode'
max_timesteps: 500
state_dim: 14
overwrite: False
arm_axis: 6
camera_names:
  - 'cam_high'
  - 'cam_low'
  - 'cam_left'
  - 'cam_right'
ros_topics:
  camera_left: '/camera_left/rgb/image_raw'
  camera_right: '/camera_right/rgb/image_raw'
  camera_bottom: '/camera_bottom/rgb/image_raw'
  camera_head: '/camera_head/rgb/image_raw'
  left_master_arm: '/left_master_arm_joint_states'
  left_slave_arm: '/left_slave_arm_joint_states'
  right_master_arm: '/right_master_arm_joint_states'
  right_slave_arm: '/right_slave_arm_joint_states'
  left_aloha_state: '/left_slave_arm_aloha_state'
  right_aloha_state: '/right_slave_arm_aloha_state'
robot_env: {
  # TODO change the path to the correct one
  rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
  rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
  arm_axis: 6,
  head_camera: '241122071186',
  bottom_camera: '152122078546',
  left_camera: '150622070125',
  right_camera: '151222072576',
  init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
  init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
  # init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
  # init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
  # init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
  # init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
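The 38-line hunk above (its own file header is not shown in this diff) is the top-level data-collection configuration: camera names, the ROS image and joint-state topics to subscribe to, and a robot_env block with per-arm config paths, camera serial numbers, and initial joint angles. A minimal sketch of how such a config could be consumed is below; PyYAML and the file name aloha_collect_config.yaml are assumptions, not part of this diff.

# Minimal sketch, assuming PyYAML; the config file name is hypothetical.
import yaml

with open('aloha_collect_config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# Each camera name is expected to correspond to one of the image topics above.
for cam in cfg['camera_names']:
    print(f'camera: {cam}')
for name, topic in cfg['ros_topics'].items():
    print(f'{name} -> {topic}')

# robot_env bundles the per-arm YAML paths, camera serial numbers, and the
# initial joint angles (one value per joint plus gripper).
print(cfg['robot_env']['init_left_arm_angle'])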
@@ -0,0 +1,9 @@
arm_ip: "192.168.1.18"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8089
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True
@@ -0,0 +1,9 @@
arm_ip: "192.168.1.19"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8090
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True
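The two hunks above are the per-arm configs referenced from robot_env as rm_left_arm.yaml and rm_right_arm.yaml: each gives the arm's IP and port, the local feedback endpoint, and per-joint gains. A rough validation sketch follows; PyYAML is an assumption and the file names are taken from the robot_env paths above.

# Sketch only: load both arm configs and sanity-check the gain list length.
import yaml

for path in ('rm_left_arm.yaml', 'rm_right_arm.yaml'):
    with open(path, 'r') as f:
        arm = yaml.safe_load(f)
    # arm_ki carries one gain per joint, so its length should match arm_axis
    # (6 entries for the rm65 variant, 7 for the commented-out rm75 variant).
    assert len(arm['arm_ki']) == arm['arm_axis'], path
    print(f"{path}: arm at {arm['arm_ip']}:{arm['arm_port']}, "
          f"local endpoint {arm['local_ip']}:{arm['local_port']}")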
@@ -0,0 +1,4 @@
port: /dev/ttyUSB1
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6
@@ -0,0 +1,4 @@
port: /dev/ttyUSB0
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6
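These two small hunks configure serial devices (one per USB port) with a fixed hex frame to send. A sketch of replaying that frame is below; pyserial is an assumption (it is not declared anywhere in this diff), and what the frame "55 AA 02 00 00 67" triggers on the device is not specified here.

# Sketch only, assuming pyserial; values are taken from the second config above.
import serial

PORT = '/dev/ttyUSB0'
BAUDRATE = 460800
HEX_DATA = '55 AA 02 00 00 67'

with serial.Serial(PORT, BAUDRATE, timeout=1.0) as ser:
    ser.write(bytes.fromhex(HEX_DATA))  # bytes.fromhex ignores the spaces
    reply = ser.read(64)                # read up to 64 bytes of response, if any
    print(reply.hex(' '))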
@@ -0,0 +1,6 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
dataset_name: 'episode'
episode_idx: 1
FPS: 30
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints
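This hunk is the visualization config: which dataset directory and episode to plot, the recording FPS, and the joint names used as plot labels. A sketch of how it might resolve to an episode file is below; the "{dataset_name}_{episode_idx}.hdf5" naming and the HDF5 layout are assumptions, not confirmed by this diff (h5py itself is a declared dependency in pyproject.toml).

# Sketch only: open the configured episode and list its recorded datasets.
import os
import h5py

dataset_dir = '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
dataset_name = 'episode'
episode_idx = 1

path = os.path.join(dataset_dir, f'{dataset_name}_{episode_idx}.hdf5')  # assumed naming
with h5py.File(path, 'r') as root:
    root.visit(print)  # walk the file to see what was recorded for this episode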
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
[tool.poetry]
name = "shadow_rm_aloha"
version = "0.1.1"
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
    "Operating System :: POSIX :: Linux amd64",
    "Programming Language :: Python :: 3.10",
]

[tool.poetry.dependencies]
python = ">=3.10"
matplotlib = ">=3.9.2"
h5py = ">=3.12.1"
# rospy = ">=1.17.0"
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }


[tool.poetry.dev-dependencies]  # Development-time dependencies, e.g. testing and documentation tools.
pytest = ">=8.3"
black = ">=24.10.0"


[tool.poetry.plugins."scripts"]  # Command-line scripts so users can run specified functions from the terminal.


[tool.poetry.group.dev.dependencies]


[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"
@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.0.2)
project(shadow_rm_aloha)

find_package(catkin REQUIRED COMPONENTS
  rospy
  sensor_msgs
  cv_bridge
  image_transport
  std_msgs
  message_generation
)

add_service_files(
  FILES
  GetArmStatus.srv
  GetImage.srv
  MoveArm.srv
)

generate_messages(
  DEPENDENCIES
  sensor_msgs
  std_msgs
)

catkin_package(
  CATKIN_DEPENDS message_runtime rospy std_msgs
)

include_directories(
  ${catkin_INCLUDE_DIRS}
)

install(PROGRAMS
  arm_node/slave_arm_publisher.py
  arm_node/master_arm_publisher.py
  arm_node/slave_arm_pub_sub.py
  camera_node/camera_publisher.py
  data_sub_process/aloha_data_synchronizer.py
  data_sub_process/aloha_data_collect.py
  DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)
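The CMakeLists hunk above declares three ROS services (GetArmStatus, GetImage, MoveArm) and installs the Python nodes as executables. Once catkin has generated the service types, a Python node can import them from shadow_rm_aloha.srv. A rough client sketch follows; the service name '/get_arm_status' and the empty request are assumptions, since the actual .srv fields are not shown in this diff.

# Sketch only: call one of the generated services with rospy.
import rospy
from shadow_rm_aloha.srv import GetArmStatus

rospy.init_node('arm_status_client')
rospy.wait_for_service('/get_arm_status')                 # hypothetical service name
get_arm_status = rospy.ServiceProxy('/get_arm_status', GetArmStatus)
response = get_arm_status()                               # assumes an empty request
print(response)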
Some files were not shown because too many files have changed in this diff.