init commit
This commit is contained in:
101
workflows/simbox/core/controllers/README.md
Normal file
101
workflows/simbox/core/controllers/README.md
Normal file
@@ -0,0 +1,101 @@
|
||||
# Arm Controllers
|
||||
|
||||
Template-based arm controllers for CuRobo motion planning. All arm controllers inherit from `TemplateController` and customize behavior via overrides.
|
||||
|
||||
## Available controllers
|
||||
|
||||
| Controller | Robot | Notes |
|
||||
|------------|--------|------|
|
||||
| `FR3Controller` | Franka FR3 (Panda arm, 7+1 gripper) | Single arm only; larger collision cache (1000). |
|
||||
| `FrankaRobotiq85Controller` | Franka + Robotiq 85 (7+2 gripper) | Single arm only; custom gripper action (inverted, clip 0–5). |
|
||||
| `Genie1Controller` | Genie1 dual arm (7 DoF per arm) | Left/right via `robot_file`; path-selection weights. |
|
||||
| `Lift2Controller` | ARX-Lift2 dual arm (6 DoF arm) | Left/right; custom world (cuboid offset 0.02), grasp axis 0 (x), in-plane rotation index 5. |
|
||||
| `SplitAlohaController` | Agilex Split Aloha dual arm (6 DoF per arm) | Left/right; grasp axis 2 (z); optional `joint_ctrl`. |
|
||||
|
||||
Register a controller by importing it (see `__init__.py`) so it is added to `CONTROLLER_DICT`.
|
||||
|
||||
---
|
||||
|
||||
## Customizing a robot arm controller
|
||||
|
||||
Subclass `TemplateController` and implement or override the following.
|
||||
|
||||
### 1. Required: `_configure_joint_indices(self, robot_file: str)`
|
||||
|
||||
Set joint names and indices for the planner and the simulation articulation.
|
||||
|
||||
You must set:
|
||||
|
||||
- **`self.raw_js_names`** – Joint names in the **planner / CuRobo** order (arm only, no gripper). Used for `get_ordered_joint_state(raw_js_names)` when building `cmd_plan`.
|
||||
- **`self.cmd_js_names`** – Same as `raw_js_names` use the **scene/articulation** names in the robot usd (e.g. `fl_joint1`… or `idx21_arm_l_joint1`…).
|
||||
- **`self.arm_indices`** – Indices of arm joints in the **simulation** `dof_names` (e.g. `np.array([0,1,2,3,4,5,6])`).
|
||||
- **`self.gripper_indices`** – Indices of gripper joints in the simulation (e.g. `np.array([7])` or `np.array([7,8])`).
|
||||
- **`self.reference_prim_path`** – Prim path used for collision reference (e.g. `self.task.robots[self.name].fl_base_path` for left arm).
|
||||
- **`self.lr_name`** – `"left"` or `"right"` (for dual-arm).
|
||||
- **`self._gripper_state`** – Initial gripper state from robot (e.g. `1.0 if self.robot.left_gripper_state == 1.0 else -1.0`). By convention, `1.0` means **open**, `-1.0` means **closed**.
|
||||
- **`self._gripper_joint_position`** – Initial gripper joint position(s), shape matching `gripper_indices` (e.g. `np.array([1.0])` or `np.array([5.0, 5.0])`).
|
||||
|
||||
For dual-arm, branch on `"left"` / `"right"` in `robot_file` and set the above per arm. For single-arm, only implement the arm you support and we set it as left arm.
|
||||
|
||||
### 2. Required: `_get_default_ignore_substring(self) -> List[str]`
|
||||
|
||||
Return the default list of name substrings for collision filtering (e.g. `["material", "Plane", "conveyor", "scene", "table"]`). The controller name is appended automatically. Override to add or remove terms (e.g. `"fluid"` for some setups).
|
||||
|
||||
### 3. Required: `get_gripper_action(self)`
|
||||
|
||||
Map the logical gripper state to gripper joint targets.
|
||||
|
||||
- **Input**: uses `self._gripper_state` (1.0 = open, -1.0 = closed) and `self._gripper_joint_position` as the magnitude / joint-space template.
|
||||
- **Default mapping**: for simple parallel grippers, a good starting point is:
|
||||
|
||||
```python
|
||||
def get_gripper_action(self):
|
||||
return np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 0.04)
|
||||
```
|
||||
|
||||
- **Robot-specific variants**: some robots change the range or sign (e.g. Robotiq85 uses two joints, inverted sign, and clips to `[0, 5]`). Adjust the formula and clip range, but keep the convention that `self._gripper_state` is `1.0` for **open** and `-1.0` for **closed**.
|
||||
|
||||
### 4. Optional overrides
|
||||
|
||||
Override only when the default template behavior is wrong for your robot.
|
||||
|
||||
- **`_load_world(self, use_default: bool = True)`**
|
||||
Default uses `WorldConfig()` when `use_default=True`, and when `False` uses a table with cuboid z offset `10.5`. Override if your table height or world is different (e.g. Genie1 uses `5.02`, Lift2 uses `0.02`).
|
||||
|
||||
- **`_get_motion_gen_collision_cache(self)`**
|
||||
Default: `{"obb": 700, "mesh": 700}`. Override to change cache size (e.g. FR3 uses `1000`).
|
||||
|
||||
- **`_get_grasp_approach_linear_axis(self) -> int`**
|
||||
Default: `2` (z-axis). Override if your grasp approach constraint uses another axis (e.g. Lift2 uses `0` for x).
|
||||
|
||||
- **`_get_sort_path_weights(self) -> Optional[List[float]]`**
|
||||
Default: `None` (equal weights). Override to pass per-joint weights for batch path selection (e.g. Genie1 uses `[1,1,1,1,3,3,1]` for 7 joints).
|
||||
|
||||
- **`get_gripper_action(self)`**
|
||||
Default: `np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 0.04)`. Override if your gripper mapping or range differs (e.g. Robotiq85: inverted sign and clip to `[0, 5]`).
|
||||
|
||||
### 5. Registration
|
||||
|
||||
- Use the `@register_controller` decorator on your class.
|
||||
- Import the new controller in `__init__.py` so it is registered in `CONTROLLER_DICT`.
|
||||
|
||||
### 6. Robot config (YAML)
|
||||
|
||||
Your robot must have a CuRobo config YAML (passed as `robot_file`) with at least:
|
||||
|
||||
- `robot_cfg` with `kinematics` (e.g. `urdf_path`, `base_link`, `ee_link`).
|
||||
|
||||
Template uses this for `_load_robot`, `_load_kin_model`, and `_init_motion_gen`; no code change needed if the YAML is correct.
|
||||
|
||||
---
|
||||
|
||||
## Summary checklist for a new arm
|
||||
|
||||
1. Add a new file, e.g. `myrobot_controller.py`.
|
||||
2. Subclass `TemplateController` and apply `@register_controller`.
|
||||
3. Implement **`_configure_joint_indices(robot_file)`** (joint names and indices for planner and sim).
|
||||
4. Implement **`_get_default_ignore_substring()`** (collision ignore list).
|
||||
5. Implement **`get_gripper_action(self)`** to map `self._gripper_state` (1.0 = open, -1.0 = closed) and `self._gripper_joint_position` to gripper joint targets (clip to a sensible range for your hardware).
|
||||
6. Override **`_load_world`** only if table/world differs from default.
|
||||
7. Override **`_get_motion_gen_collision_cache`** / **`_get_grasp_approach_linear_axis`** / **`_get_sort_path_weights`** only if needed.
|
||||
8. Import the new controller in `__init__.py`.
|
||||
28
workflows/simbox/core/controllers/__init__.py
Normal file
28
workflows/simbox/core/controllers/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Template-based controllers. Import subclasses to register them with CONTROLLER_DICT."""
|
||||
|
||||
from core.controllers.base_controller import CONTROLLER_DICT
|
||||
from core.controllers.fr3_controller import FR3Controller
|
||||
from core.controllers.frankarobotiq85_controller import FrankaRobotiq85Controller
|
||||
from core.controllers.genie1_controller import Genie1Controller
|
||||
from core.controllers.lift2_controller import Lift2Controller
|
||||
from core.controllers.splitaloha_controller import SplitAlohaController
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
__all__ = [
|
||||
"TemplateController",
|
||||
"FR3Controller",
|
||||
"FrankaRobotiq85Controller",
|
||||
"Genie1Controller",
|
||||
"Lift2Controller",
|
||||
"SplitAlohaController",
|
||||
]
|
||||
|
||||
|
||||
def get_controller_cls(category_name):
|
||||
"""Get controller class by category name."""
|
||||
return CONTROLLER_DICT[category_name]
|
||||
|
||||
|
||||
def get_controller_dict():
|
||||
"""Get controller dictionary."""
|
||||
return CONTROLLER_DICT
|
||||
11
workflows/simbox/core/controllers/base_controller.py
Normal file
11
workflows/simbox/core/controllers/base_controller.py
Normal file
@@ -0,0 +1,11 @@
|
||||
CONTROLLER_DICT = {}
|
||||
|
||||
|
||||
def register_controller(target_class):
|
||||
# key = "_".join(re.sub(r"([A-Z0-9])", r" \1", target_class.__name__).split()).lower()
|
||||
key = target_class.__name__
|
||||
assert key.endswith("Controller")
|
||||
key = key.removesuffix("Controller")
|
||||
# assert key not in CONTROLLER_DICT
|
||||
CONTROLLER_DICT[key] = target_class
|
||||
return target_class
|
||||
37
workflows/simbox/core/controllers/fr3_controller.py
Normal file
37
workflows/simbox/core/controllers/fr3_controller.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""FR3 controller – template-based."""
|
||||
|
||||
import numpy as np
|
||||
from core.controllers.base_controller import register_controller
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_controller
|
||||
class FR3Controller(TemplateController):
|
||||
def _get_default_ignore_substring(self):
|
||||
return ["material", "Plane", "conveyor", "scene", "table"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
self.raw_js_names = [
|
||||
"panda_joint1",
|
||||
"panda_joint2",
|
||||
"panda_joint3",
|
||||
"panda_joint4",
|
||||
"panda_joint5",
|
||||
"panda_joint6",
|
||||
"panda_joint7",
|
||||
]
|
||||
if "left" in robot_file:
|
||||
self.cmd_js_names = list(self.raw_js_names)
|
||||
self.arm_indices = np.array(self.robot.cfg["left_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["left_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fl_base_path
|
||||
self.lr_name = "left"
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self._gripper_joint_position = np.array([1.0])
|
||||
|
||||
def _get_motion_gen_collision_cache(self):
|
||||
"""FR3 uses larger collision cache (1000) for MotionGenConfig than template default (700)."""
|
||||
return {"obb": 1000, "mesh": 1000}
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Franka Robotiq85 controller – template-based."""
|
||||
|
||||
import numpy as np
|
||||
from core.controllers.base_controller import register_controller
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_controller
|
||||
class FrankaRobotiq85Controller(TemplateController):
|
||||
def _get_default_ignore_substring(self):
|
||||
return ["material", "Plane", "conveyor", "scene"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
self.raw_js_names = [
|
||||
"panda_joint1",
|
||||
"panda_joint2",
|
||||
"panda_joint3",
|
||||
"panda_joint4",
|
||||
"panda_joint5",
|
||||
"panda_joint6",
|
||||
"panda_joint7",
|
||||
]
|
||||
if "left" in robot_file:
|
||||
self.cmd_js_names = list(self.raw_js_names)
|
||||
self.arm_indices = np.array(self.robot.cfg["left_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["left_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fl_base_path
|
||||
self.lr_name = "left"
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self._gripper_joint_position = np.array([5.0, 5.0])
|
||||
|
||||
def get_gripper_action(self):
|
||||
"""Robotiq85: inverted sign and clip to [0, 5] (two gripper joints)."""
|
||||
return np.clip((-1) * self._gripper_state * self._gripper_joint_position, 0.0, 5)
|
||||
58
workflows/simbox/core/controllers/genie1_controller.py
Normal file
58
workflows/simbox/core/controllers/genie1_controller.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Genie1 dual-arm controller – template-based."""
|
||||
|
||||
import numpy as np
|
||||
from core.controllers.base_controller import register_controller
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_controller
|
||||
class Genie1Controller(TemplateController):
|
||||
def _get_default_ignore_substring(self):
|
||||
return ["material", "Plane", "conveyor", "scene", "table", "fluid"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
if "left" in robot_file:
|
||||
self.cmd_js_names = [
|
||||
"idx21_arm_l_joint1",
|
||||
"idx22_arm_l_joint2",
|
||||
"idx23_arm_l_joint3",
|
||||
"idx24_arm_l_joint4",
|
||||
"idx25_arm_l_joint5",
|
||||
"idx26_arm_l_joint6",
|
||||
"idx27_arm_l_joint7",
|
||||
]
|
||||
self.arm_indices = np.array(self.robot.cfg["left_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["left_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fl_base_path
|
||||
self.lr_name = "left"
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
elif "right" in robot_file:
|
||||
self.cmd_js_names = [
|
||||
"idx61_arm_r_joint1",
|
||||
"idx62_arm_r_joint2",
|
||||
"idx63_arm_r_joint3",
|
||||
"idx64_arm_r_joint4",
|
||||
"idx65_arm_r_joint5",
|
||||
"idx66_arm_r_joint6",
|
||||
"idx67_arm_r_joint7",
|
||||
]
|
||||
self.arm_indices = np.array(self.robot.cfg["right_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["right_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fr_base_path
|
||||
self.lr_name = "right"
|
||||
self._gripper_state = 1.0 if self.robot.right_gripper_state == 1.0 else -1.0
|
||||
else:
|
||||
raise NotImplementedError("robot_file must contain 'left' or 'right'")
|
||||
self.raw_js_names = list(self.cmd_js_names)
|
||||
self._gripper_joint_position = np.array([1.0])
|
||||
|
||||
def get_gripper_action(self):
|
||||
return np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 1.0)
|
||||
|
||||
def _get_sort_path_weights(self):
|
||||
"""Genie1: weight joints 4 and 5 (index 4,5) by 3.0 for path selection."""
|
||||
return [1.0, 1.0, 1.0, 1.0, 3.0, 3.0, 1.0]
|
||||
|
||||
def mobile_move(self, target: np.ndarray, joint_indices: np.ndarray = None, initial_position: np.ndarray = None):
|
||||
raise NotImplementedError
|
||||
50
workflows/simbox/core/controllers/lift2_controller.py
Normal file
50
workflows/simbox/core/controllers/lift2_controller.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Lift2 mobile manipulator controller – template-based."""
|
||||
|
||||
import numpy as np
|
||||
from core.controllers.base_controller import register_controller
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_controller
|
||||
class Lift2Controller(TemplateController):
|
||||
def _get_default_ignore_substring(self):
|
||||
return ["material", "Plane", "conveyor", "scene", "table", "fluid"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
self.raw_js_names = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
|
||||
if "left" in robot_file:
|
||||
self.cmd_js_names = ["fl_joint1", "fl_joint2", "fl_joint3", "fl_joint4", "fl_joint5", "fl_joint6"]
|
||||
self.arm_indices = np.array(self.robot.cfg["left_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["left_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fl_base_path
|
||||
self.lr_name = "left"
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
elif "right" in robot_file:
|
||||
self.cmd_js_names = ["fr_joint1", "fr_joint2", "fr_joint3", "fr_joint4", "fr_joint5", "fr_joint6"]
|
||||
self.arm_indices = np.array(self.robot.cfg["right_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["right_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fr_base_path
|
||||
self.lr_name = "right"
|
||||
self._gripper_state = 1.0 if self.robot.right_gripper_state == 1.0 else -1.0
|
||||
else:
|
||||
raise NotImplementedError("robot_file must contain 'left' or 'right'")
|
||||
self._gripper_joint_position = np.array([1.0])
|
||||
|
||||
def get_gripper_action(self):
|
||||
return np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 0.1)
|
||||
|
||||
def forward(self, manip_cmd, eps=5e-3):
|
||||
ee_trans, ee_ori = manip_cmd[0:2]
|
||||
gripper_fn = manip_cmd[2]
|
||||
params = manip_cmd[3]
|
||||
assert hasattr(self, gripper_fn)
|
||||
method = getattr(self, gripper_fn)
|
||||
if gripper_fn in ["in_plane_rotation", "mobile_move", "dummy_forward", "joint_ctrl"]:
|
||||
return method(**params)
|
||||
elif gripper_fn in ["update_pose_cost_metric", "update_specific"]:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps=eps, skip_plan=True)
|
||||
else:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps=eps)
|
||||
50
workflows/simbox/core/controllers/splitaloha_controller.py
Normal file
50
workflows/simbox/core/controllers/splitaloha_controller.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""SplitAloha dual-arm controller – template-based."""
|
||||
|
||||
import numpy as np
|
||||
from core.controllers.base_controller import register_controller
|
||||
from core.controllers.template_controller import TemplateController
|
||||
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@register_controller
|
||||
class SplitAlohaController(TemplateController):
|
||||
def _get_default_ignore_substring(self):
|
||||
return ["material", "Plane", "conveyor", "scene", "table", "fluid"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
self.raw_js_names = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
|
||||
if "left" in robot_file:
|
||||
self.cmd_js_names = ["fl_joint1", "fl_joint2", "fl_joint3", "fl_joint4", "fl_joint5", "fl_joint6"]
|
||||
self.arm_indices = np.array(self.robot.cfg["left_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["left_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fl_base_path
|
||||
self.lr_name = "left"
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
elif "right" in robot_file:
|
||||
self.cmd_js_names = ["fr_joint1", "fr_joint2", "fr_joint3", "fr_joint4", "fr_joint5", "fr_joint6"]
|
||||
self.arm_indices = np.array(self.robot.cfg["right_joint_indices"])
|
||||
self.gripper_indices = np.array(self.robot.cfg["right_gripper_indices"])
|
||||
self.reference_prim_path = self.task.robots[self.name].fr_base_path
|
||||
self.lr_name = "right"
|
||||
self._gripper_state = 1.0 if self.robot.right_gripper_state == 1.0 else -1.0
|
||||
else:
|
||||
raise NotImplementedError("robot_file must contain 'left' or 'right'")
|
||||
self._gripper_joint_position = np.array([1.0])
|
||||
|
||||
def get_gripper_action(self):
|
||||
return np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 0.1)
|
||||
|
||||
def forward(self, manip_cmd, eps=5e-3):
|
||||
ee_trans, ee_ori = manip_cmd[0:2]
|
||||
gripper_fn = manip_cmd[2]
|
||||
params = manip_cmd[3]
|
||||
assert hasattr(self, gripper_fn)
|
||||
method = getattr(self, gripper_fn)
|
||||
if gripper_fn in ["in_plane_rotation", "mobile_move", "dummy_forward", "joint_ctrl"]:
|
||||
return method(**params)
|
||||
elif gripper_fn in ["update_pose_cost_metric", "update_specific"]:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps=eps, skip_plan=True)
|
||||
else:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps=eps)
|
||||
599
workflows/simbox/core/controllers/template_controller.py
Normal file
599
workflows/simbox/core/controllers/template_controller.py
Normal file
@@ -0,0 +1,599 @@
|
||||
"""
|
||||
Template Controller base class for robot motion planning.
|
||||
|
||||
Common functionality extracted from FR3, FrankaRobotiq85, Genie1, Lift2, SplitAloha.
|
||||
Subclasses implement _get_default_ignore_substring() and _configure_joint_indices().
|
||||
"""
|
||||
|
||||
import random
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from core.utils.constants import CUROBO_BATCH_SIZE
|
||||
from core.utils.plan_utils import (
|
||||
filter_paths_by_position_error,
|
||||
filter_paths_by_rotation_error,
|
||||
sort_by_difference_js,
|
||||
)
|
||||
from curobo.cuda_robot_model.cuda_robot_model import CudaRobotModel
|
||||
from curobo.geom.sdf.world import CollisionCheckerType
|
||||
from curobo.geom.sphere_fit import SphereFitType
|
||||
from curobo.geom.types import WorldConfig
|
||||
from curobo.types.base import TensorDeviceType
|
||||
from curobo.types.math import Pose
|
||||
from curobo.types.robot import JointState, RobotConfig
|
||||
from curobo.util.usd_helper import UsdHelper
|
||||
from curobo.util_file import get_world_configs_path, join_path, load_yaml
|
||||
from curobo.wrap.reacher.ik_solver import IKSolver, IKSolverConfig
|
||||
from curobo.wrap.reacher.motion_gen import (
|
||||
MotionGen,
|
||||
MotionGenConfig,
|
||||
MotionGenPlanConfig,
|
||||
PoseCostMetric,
|
||||
)
|
||||
from omni.isaac.core import World
|
||||
from omni.isaac.core.controllers import BaseController
|
||||
from omni.isaac.core.tasks import BaseTask
|
||||
from omni.isaac.core.utils.prims import get_prim_at_path
|
||||
from omni.isaac.core.utils.transformations import (
|
||||
get_relative_transform,
|
||||
pose_from_tf_matrix,
|
||||
)
|
||||
from omni.isaac.core.utils.types import ArticulationAction
|
||||
|
||||
|
||||
# pylint: disable=line-too-long,unused-argument
|
||||
class TemplateController(BaseController):
|
||||
"""Base controller for CuRobo-based motion planning. Supports single and batch planning."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
robot_file: str,
|
||||
task: BaseTask,
|
||||
world: World,
|
||||
constrain_grasp_approach: bool = False,
|
||||
collision_activation_distance: float = 0.03,
|
||||
ignore_substring: Optional[List[str]] = None,
|
||||
use_batch: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(name=name)
|
||||
self.name = name
|
||||
self.world = world
|
||||
self.task = task
|
||||
self.robot = self.task.robots[name]
|
||||
self.ignore_substring = self._get_default_ignore_substring()
|
||||
if ignore_substring is not None:
|
||||
self.ignore_substring = ignore_substring
|
||||
self.ignore_substring.append(name)
|
||||
self.use_batch = use_batch
|
||||
self.constrain_grasp_approach = constrain_grasp_approach
|
||||
self.collision_activation_distance = collision_activation_distance
|
||||
self.usd_help = UsdHelper()
|
||||
self.tensor_args = TensorDeviceType()
|
||||
self.init_curobo = False
|
||||
self.robot_file = robot_file
|
||||
self.num_plan_failed = 0
|
||||
self.raw_js_names = []
|
||||
self.cmd_js_names = []
|
||||
self.arm_indices = np.array([])
|
||||
self.gripper_indices = np.array([])
|
||||
self.reference_prim_path = None
|
||||
self.lr_name = None
|
||||
self._ee_trans = 0.0
|
||||
self._ee_ori = 0.0
|
||||
self._gripper_state = 1.0
|
||||
self._gripper_joint_position = np.array([1.0])
|
||||
self.idx_list = None
|
||||
|
||||
self._configure_joint_indices(robot_file)
|
||||
self._load_robot(robot_file)
|
||||
self._load_kin_model()
|
||||
self._load_world()
|
||||
self._init_motion_gen()
|
||||
|
||||
self.usd_help.load_stage(self.world.stage)
|
||||
self.cmd_plan = None
|
||||
self.cmd_idx = 0
|
||||
self._step_idx = 0
|
||||
self.num_last_cmd = 0
|
||||
self.ds_ratio = 1
|
||||
|
||||
def _get_default_ignore_substring(self) -> List[str]:
|
||||
return ["material", "Plane", "conveyor", "scene", "table"]
|
||||
|
||||
def _configure_joint_indices(self, robot_file: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _load_robot(self, robot_file: str) -> None:
|
||||
self.robot_cfg = load_yaml(robot_file)["robot_cfg"]
|
||||
|
||||
def _load_kin_model(self) -> None:
|
||||
urdf_file = self.robot_cfg["kinematics"]["urdf_path"]
|
||||
base_link = self.robot_cfg["kinematics"]["base_link"]
|
||||
ee_link = self.robot_cfg["kinematics"]["ee_link"]
|
||||
robot_cfg = RobotConfig.from_basic(urdf_file, base_link, ee_link, self.tensor_args)
|
||||
self.kin_model = CudaRobotModel(robot_cfg.kinematics)
|
||||
|
||||
def _load_world(self, use_default: bool = True) -> None:
|
||||
if use_default:
|
||||
self.world_cfg = WorldConfig()
|
||||
else:
|
||||
world_cfg_table = WorldConfig.from_dict(
|
||||
load_yaml(join_path(get_world_configs_path(), "collision_table.yml"))
|
||||
)
|
||||
self._world_cfg_table = world_cfg_table
|
||||
self._world_cfg_table.cuboid[0].pose[2] -= 10.5
|
||||
world_cfg1 = WorldConfig.from_dict(
|
||||
load_yaml(join_path(get_world_configs_path(), "collision_table.yml"))
|
||||
).get_mesh_world()
|
||||
world_cfg1.mesh[0].name += "_mesh"
|
||||
world_cfg1.mesh[0].pose[2] = -10.5
|
||||
self.world_cfg = WorldConfig(cuboid=world_cfg_table.cuboid, mesh=world_cfg1.mesh)
|
||||
|
||||
def _get_motion_gen_collision_cache(self):
|
||||
"""Override in subclasses to use different cache size (e.g. FR3 uses 1000)."""
|
||||
return {"obb": 700, "mesh": 700}
|
||||
|
||||
def _get_grasp_approach_linear_axis(self) -> int:
|
||||
"""Axis for grasp approach constraint (0=x, 1=y, 2=z). Override in subclasses (e.g. Lift2 uses 0)."""
|
||||
if self.robot.cfg["ee_axis"] == "x":
|
||||
return 0
|
||||
elif self.robot.cfg["ee_axis"] == "y":
|
||||
return 1
|
||||
elif self.robot.cfg["ee_axis"] == "z":
|
||||
return 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_sort_path_weights(self) -> Optional[List[float]]:
|
||||
"""Optional per-joint weights for sort_by_difference_js.
|
||||
|
||||
Used when selecting among batch paths. None means equal weights.
|
||||
Override in subclasses (e.g. Genie1).
|
||||
"""
|
||||
return None
|
||||
|
||||
def _init_motion_gen(self) -> None:
|
||||
pose_metric = None
|
||||
if self.constrain_grasp_approach:
|
||||
pose_metric = PoseCostMetric.create_grasp_approach_metric(
|
||||
offset_position=0.1,
|
||||
linear_axis=self._get_grasp_approach_linear_axis(),
|
||||
)
|
||||
if self.use_batch:
|
||||
self.plan_config = MotionGenPlanConfig(
|
||||
enable_graph=True,
|
||||
enable_opt=True,
|
||||
need_graph_success=True,
|
||||
enable_graph_attempt=4,
|
||||
max_attempts=4,
|
||||
enable_finetune_trajopt=True,
|
||||
parallel_finetune=True,
|
||||
time_dilation_factor=1.0,
|
||||
)
|
||||
else:
|
||||
self.plan_config = MotionGenPlanConfig(
|
||||
enable_graph=False,
|
||||
enable_graph_attempt=7,
|
||||
max_attempts=10,
|
||||
pose_cost_metric=pose_metric,
|
||||
enable_finetune_trajopt=True,
|
||||
time_dilation_factor=1.0,
|
||||
)
|
||||
motion_gen_config = MotionGenConfig.load_from_robot_config(
|
||||
self.robot_cfg,
|
||||
self.world_cfg,
|
||||
self.tensor_args,
|
||||
interpolation_dt=0.01,
|
||||
collision_activation_distance=self.collision_activation_distance,
|
||||
trajopt_tsteps=32,
|
||||
collision_checker_type=CollisionCheckerType.MESH,
|
||||
use_cuda_graph=True,
|
||||
self_collision_check=True,
|
||||
collision_cache=self._get_motion_gen_collision_cache(),
|
||||
num_trajopt_seeds=12,
|
||||
num_graph_seeds=12,
|
||||
optimize_dt=True,
|
||||
trajopt_dt=None,
|
||||
trim_steps=None,
|
||||
project_pose_to_goal_frame=False,
|
||||
)
|
||||
ik_config = IKSolverConfig.load_from_robot_config(
|
||||
self.robot_cfg,
|
||||
self.world_cfg,
|
||||
rotation_threshold=0.05,
|
||||
position_threshold=0.005,
|
||||
num_seeds=20,
|
||||
self_collision_check=True,
|
||||
self_collision_opt=True,
|
||||
tensor_args=self.tensor_args,
|
||||
use_cuda_graph=True,
|
||||
collision_checker_type=CollisionCheckerType.MESH,
|
||||
collision_cache={"obb": 700, "mesh": 700},
|
||||
)
|
||||
self.ik_solver = IKSolver(ik_config)
|
||||
self.motion_gen = MotionGen(motion_gen_config)
|
||||
print("warming up..")
|
||||
if self.use_batch:
|
||||
self.motion_gen.warmup(parallel_finetune=True, batch=CUROBO_BATCH_SIZE)
|
||||
else:
|
||||
self.motion_gen.warmup(enable_graph=True, warmup_js_trajopt=False)
|
||||
self.world_model = self.motion_gen.world_collision
|
||||
self.motion_gen.clear_world_cache()
|
||||
self.motion_gen.reset(reset_seed=False)
|
||||
self.motion_gen.update_world(self.world_cfg)
|
||||
|
||||
def update_pose_cost_metric(self, hold_vec_weight: Optional[List[float]] = None) -> None:
|
||||
# reference: https://curobo.org/advanced_examples/3_constrained_planning.html
|
||||
# [angular-x, angular-y, angular-z, linear-x, linear-y, linear-z]
|
||||
# For example,
|
||||
# when hold_vec_weight is None, the corresponding list is [0, 0, 0, 0, 0, 0],
|
||||
# there is no cost added in any directions.
|
||||
# When hold_vec_weight = [1, 1, 1, 0, 0, 0], the tool orientation is holed.
|
||||
# assert hold_vec_weight is None or len(hold_vec_weight) == 6
|
||||
if hold_vec_weight:
|
||||
pose_cost_metric = PoseCostMetric(
|
||||
hold_partial_pose=True,
|
||||
hold_vec_weight=self.motion_gen.tensor_args.to_device(hold_vec_weight),
|
||||
)
|
||||
else:
|
||||
pose_cost_metric = None
|
||||
self.plan_config.pose_cost_metric = pose_cost_metric
|
||||
|
||||
def update(self) -> None:
|
||||
obstacles = self.usd_help.get_obstacles_from_stage(
|
||||
ignore_substring=self.ignore_substring, reference_prim_path=self.reference_prim_path
|
||||
).get_collision_check_world()
|
||||
if self.motion_gen is not None:
|
||||
self.motion_gen.update_world(obstacles)
|
||||
self.world_cfg = obstacles
|
||||
|
||||
def reset(self, ignore_substring: Optional[str] = None) -> None:
|
||||
if ignore_substring:
|
||||
self.ignore_substring = ignore_substring
|
||||
self.update()
|
||||
self.init_curobo = True
|
||||
self.cmd_plan = None
|
||||
self.cmd_idx = 0
|
||||
self.num_plan_failed = 0
|
||||
if self.lr_name == "left":
|
||||
self._gripper_state = 1.0 if self.robot.left_gripper_state == 1.0 else -1.0
|
||||
elif self.lr_name == "right":
|
||||
self._gripper_state = 1.0 if self.robot.right_gripper_state == 1.0 else -1.0
|
||||
if self.lr_name == "left":
|
||||
self.robot_ee_path = self.robot.fl_ee_path
|
||||
self.robot_base_path = self.robot.fl_base_path
|
||||
else:
|
||||
self.robot_ee_path = self.robot.fr_ee_path
|
||||
self.robot_base_path = self.robot.fr_base_path
|
||||
self.T_base_ee_init = get_relative_transform(
|
||||
get_prim_at_path(self.robot_ee_path), get_prim_at_path(self.robot_base_path)
|
||||
)
|
||||
self.T_world_base_init = get_relative_transform(
|
||||
get_prim_at_path(self.robot_base_path), get_prim_at_path(self.task.root_prim_path)
|
||||
)
|
||||
self.T_world_ee_init = self.T_world_base_init @ self.T_base_ee_init
|
||||
self._ee_trans, self._ee_ori = self.get_ee_pose()
|
||||
self._ee_trans = self.tensor_args.to_device(self._ee_trans)
|
||||
self._ee_ori = self.tensor_args.to_device(self._ee_ori)
|
||||
self.update_pose_cost_metric()
|
||||
|
||||
def plan_batch(self, ee_translation_goal_batch, ee_orientation_goal_batch, sim_js, js_names):
|
||||
t1 = time.time()
|
||||
torch.cuda.synchronize()
|
||||
sim_js_positions = (sim_js.positions)[np.newaxis, :]
|
||||
ik_goal = Pose(
|
||||
position=self.tensor_args.to_device(ee_translation_goal_batch),
|
||||
quaternion=self.tensor_args.to_device(ee_orientation_goal_batch),
|
||||
batch=CUROBO_BATCH_SIZE,
|
||||
)
|
||||
cu_js = JointState(
|
||||
position=self.tensor_args.to_device(np.tile(sim_js_positions, (CUROBO_BATCH_SIZE, 1))),
|
||||
velocity=self.tensor_args.to_device(np.tile(sim_js_positions, (CUROBO_BATCH_SIZE, 1))) * 0.0,
|
||||
acceleration=self.tensor_args.to_device(np.tile(sim_js_positions, (CUROBO_BATCH_SIZE, 1))) * 0.0,
|
||||
jerk=self.tensor_args.to_device(np.tile(sim_js_positions, (CUROBO_BATCH_SIZE, 1))) * 0.0,
|
||||
joint_names=js_names,
|
||||
)
|
||||
cu_js = cu_js.get_ordered_joint_state(self.cmd_js_names)
|
||||
result = self.motion_gen.plan_batch(cu_js, ik_goal, self.plan_config.clone())
|
||||
t2 = time.time()
|
||||
torch.cuda.synchronize()
|
||||
print("plan batch duration :", t2 - t1)
|
||||
return result
|
||||
|
||||
def plan(self, ee_translation_goal, ee_orientation_goal, sim_js: JointState, js_names: list):
|
||||
if self.use_batch:
|
||||
ik_goal = Pose(
|
||||
position=self.tensor_args.to_device(ee_translation_goal.unsqueeze(0).expand(CUROBO_BATCH_SIZE, -1)),
|
||||
quaternion=self.tensor_args.to_device(ee_orientation_goal.unsqueeze(0).expand(CUROBO_BATCH_SIZE, -1)),
|
||||
batch=CUROBO_BATCH_SIZE,
|
||||
)
|
||||
cu_js = JointState(
|
||||
position=self.tensor_args.to_device(np.tile((sim_js.positions)[np.newaxis, :], (CUROBO_BATCH_SIZE, 1))),
|
||||
velocity=self.tensor_args.to_device(np.tile((sim_js.positions)[np.newaxis, :], (CUROBO_BATCH_SIZE, 1)))
|
||||
* 0.0,
|
||||
acceleration=self.tensor_args.to_device(
|
||||
np.tile((sim_js.positions)[np.newaxis, :], (CUROBO_BATCH_SIZE, 1))
|
||||
)
|
||||
* 0.0,
|
||||
jerk=self.tensor_args.to_device(np.tile((sim_js.positions)[np.newaxis, :], (CUROBO_BATCH_SIZE, 1)))
|
||||
* 0.0,
|
||||
joint_names=js_names,
|
||||
)
|
||||
cu_js = cu_js.get_ordered_joint_state(self.cmd_js_names)
|
||||
return self.motion_gen.plan_batch(cu_js, ik_goal, self.plan_config.clone())
|
||||
ik_goal = Pose(
|
||||
position=self.tensor_args.to_device(ee_translation_goal),
|
||||
quaternion=self.tensor_args.to_device(ee_orientation_goal),
|
||||
)
|
||||
cu_js = JointState(
|
||||
position=self.tensor_args.to_device(sim_js.positions),
|
||||
velocity=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
acceleration=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
jerk=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
joint_names=js_names,
|
||||
)
|
||||
cu_js = cu_js.get_ordered_joint_state(self.cmd_js_names)
|
||||
return self.motion_gen.plan_single(cu_js.unsqueeze(0), ik_goal, self.plan_config.clone())
|
||||
|
||||
def forward(self, manip_cmd, eps=5e-3):
|
||||
ee_trans, ee_ori = manip_cmd[0:2]
|
||||
gripper_fn = manip_cmd[2]
|
||||
params = manip_cmd[3]
|
||||
assert hasattr(self, gripper_fn)
|
||||
method = getattr(self, gripper_fn)
|
||||
if gripper_fn in ["in_plane_rotation", "mobile_move", "dummy_forward"]:
|
||||
return method(**params)
|
||||
elif gripper_fn in ["update_pose_cost_metric", "update_specific"]:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps=eps, skip_plan=True)
|
||||
else:
|
||||
method(**params)
|
||||
return self.ee_forward(ee_trans, ee_ori, eps)
|
||||
|
||||
def ee_forward(
|
||||
self,
|
||||
ee_trans: torch.Tensor | np.ndarray,
|
||||
ee_ori: torch.Tensor | np.ndarray,
|
||||
eps=1e-4,
|
||||
skip_plan=False,
|
||||
):
|
||||
ee_trans = self.tensor_args.to_device(ee_trans)
|
||||
ee_ori = self.tensor_args.to_device(ee_ori)
|
||||
sim_js = self.robot.get_joints_state()
|
||||
js_names = self.robot.dof_names
|
||||
plan_flag = torch.logical_or(
|
||||
torch.norm(self._ee_trans - ee_trans) > eps,
|
||||
torch.norm(self._ee_ori - ee_ori) > eps,
|
||||
)
|
||||
if not skip_plan:
|
||||
if plan_flag:
|
||||
self.cmd_idx = 0
|
||||
self._step_idx = 0
|
||||
self.num_last_cmd = 0
|
||||
result = self.plan(ee_trans, ee_ori, sim_js, js_names)
|
||||
if self.use_batch:
|
||||
if result.success.any():
|
||||
self._ee_trans = ee_trans
|
||||
self._ee_ori = ee_ori
|
||||
paths = result.get_successful_paths()
|
||||
position_filter_res = filter_paths_by_position_error(
|
||||
paths, result.position_error[result.success]
|
||||
)
|
||||
rotation_filter_res = filter_paths_by_rotation_error(
|
||||
paths, result.rotation_error[result.success]
|
||||
)
|
||||
filtered_paths = [
|
||||
p for i, p in enumerate(paths) if position_filter_res[i] and rotation_filter_res[i]
|
||||
]
|
||||
if len(filtered_paths) == 0:
|
||||
filtered_paths = paths
|
||||
sort_weights = self._get_sort_path_weights() # pylint: disable=assignment-from-none
|
||||
weights_arg = self.tensor_args.to_device(sort_weights) if sort_weights is not None else None
|
||||
sorted_indices = sort_by_difference_js(filtered_paths, weights=weights_arg)
|
||||
cmd_plan = self.motion_gen.get_full_js(paths[sorted_indices[0]])
|
||||
self.idx_list = list(range(len(self.raw_js_names)))
|
||||
self.cmd_plan = cmd_plan.get_ordered_joint_state(self.raw_js_names)
|
||||
self.num_plan_failed = 0
|
||||
else:
|
||||
print("Plan did not converge to a solution.")
|
||||
self.num_plan_failed += 1
|
||||
else:
|
||||
succ = result.success.item()
|
||||
if succ:
|
||||
self._ee_trans = ee_trans
|
||||
self._ee_ori = ee_ori
|
||||
cmd_plan = result.get_interpolated_plan()
|
||||
self.idx_list = list(range(len(self.raw_js_names)))
|
||||
self.cmd_plan = cmd_plan.get_ordered_joint_state(self.raw_js_names)
|
||||
self.num_plan_failed = 0
|
||||
else:
|
||||
print("Plan did not converge to a solution.")
|
||||
self.num_plan_failed += 1
|
||||
if self.cmd_plan and self._step_idx % 1 == 0:
|
||||
cmd_state = self.cmd_plan[self.cmd_idx]
|
||||
art_action = ArticulationAction(
|
||||
cmd_state.position.cpu().numpy(),
|
||||
cmd_state.velocity.cpu().numpy() * 0.0,
|
||||
joint_indices=self.idx_list,
|
||||
)
|
||||
self.cmd_idx += self.ds_ratio
|
||||
if self.cmd_idx >= len(self.cmd_plan):
|
||||
self.cmd_idx = 0
|
||||
self.cmd_plan = None
|
||||
else:
|
||||
self.num_last_cmd += 1
|
||||
art_action = ArticulationAction(joint_positions=sim_js.positions[self.arm_indices])
|
||||
else:
|
||||
art_action = ArticulationAction(joint_positions=sim_js.positions[self.arm_indices])
|
||||
self._step_idx += 1
|
||||
arm_action = art_action.joint_positions
|
||||
gripper_action = self.get_gripper_action()
|
||||
joint_positions = np.concatenate([arm_action, gripper_action])
|
||||
self._action = {
|
||||
"joint_positions": joint_positions,
|
||||
"joint_indices": np.concatenate([self.arm_indices, self.gripper_indices]),
|
||||
"lr_name": self.lr_name,
|
||||
"arm_action": arm_action,
|
||||
"gripper_action": gripper_action,
|
||||
}
|
||||
return self._action
|
||||
|
||||
def get_gripper_action(self):
|
||||
return np.clip(self._gripper_state * self._gripper_joint_position, 0.0, 0.04)
|
||||
|
||||
def get_ee_pose(self):
|
||||
sim_js = self.robot.get_joints_state()
|
||||
q_state = torch.tensor(sim_js.positions[self.arm_indices], **self.tensor_args.as_torch_dict()).reshape(1, -1)
|
||||
ee_pose = self.kin_model.get_state(q_state)
|
||||
return ee_pose.ee_position[0].cpu().numpy(), ee_pose.ee_quaternion[0].cpu().numpy()
|
||||
|
||||
def get_armbase_pose(self):
|
||||
armbase_pose = get_relative_transform(
|
||||
get_prim_at_path(self.robot_base_path), get_prim_at_path(self.task.root_prim_path)
|
||||
)
|
||||
return pose_from_tf_matrix(armbase_pose)
|
||||
|
||||
def forward_kinematic(self, q_state: np.ndarray):
|
||||
q_state = q_state.reshape(1, -1)
|
||||
q_state = self.tensor_args.to_device(q_state)
|
||||
out = self.kin_model.get_state(q_state)
|
||||
return out.ee_position[0].cpu().numpy(), out.ee_quaternion[0].cpu().numpy()
|
||||
|
||||
def close_gripper(self):
|
||||
self._gripper_state = -1.0
|
||||
|
||||
def open_gripper(self):
|
||||
self._gripper_state = 1.0
|
||||
|
||||
def attach_obj(self, obj_prim_path: str, link_name="attached_object"):
|
||||
sim_js = self.robot.get_joints_state()
|
||||
js_names = self.robot.dof_names
|
||||
cu_js = JointState(
|
||||
position=self.tensor_args.to_device(sim_js.positions),
|
||||
velocity=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
acceleration=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
jerk=self.tensor_args.to_device(sim_js.velocities) * 0.0,
|
||||
joint_names=js_names,
|
||||
)
|
||||
self.motion_gen.attach_objects_to_robot(
|
||||
cu_js,
|
||||
[obj_prim_path],
|
||||
link_name=link_name,
|
||||
sphere_fit_type=SphereFitType.VOXEL_VOLUME_SAMPLE_SURFACE,
|
||||
world_objects_pose_offset=Pose.from_list([0, 0, 0.01, 1, 0, 0, 0], self.tensor_args),
|
||||
)
|
||||
|
||||
def detach_obj(self):
|
||||
self.motion_gen.detach_object_from_robot()
|
||||
|
||||
def update_specific(self, ignore_substring, reference_prim_path):
|
||||
obstacles = self.usd_help.get_obstacles_from_stage(
|
||||
ignore_substring=ignore_substring, reference_prim_path=reference_prim_path
|
||||
).get_collision_check_world()
|
||||
if self.motion_gen is not None:
|
||||
self.motion_gen.update_world(obstacles)
|
||||
self.world_cfg = obstacles
|
||||
|
||||
def test_single_ik(self, ee_trans, ee_ori):
|
||||
assert not self.use_batch
|
||||
ik_goal = Pose(position=self.tensor_args.to_device(ee_trans), quaternion=self.tensor_args.to_device(ee_ori))
|
||||
result = self.ik_solver.solve_single(ik_goal)
|
||||
succ = result.success.item()
|
||||
if succ: # pylint: disable=simplifiable-if-statement
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def test_batch_forward(self, ee_trans_batch_np, ee_ori_batch_np):
|
||||
ee_trans_batch = self.tensor_args.to_device(ee_trans_batch_np)
|
||||
ee_ori_batch = self.tensor_args.to_device(ee_ori_batch_np)
|
||||
sim_js = self.robot.get_joints_state()
|
||||
js_names = self.robot.dof_names
|
||||
result = self.plan_batch(ee_trans_batch, ee_ori_batch, sim_js, js_names)
|
||||
|
||||
return result
|
||||
|
||||
def test_single_forward(self, ee_trans: np.ndarray, ee_ori: np.ndarray):
|
||||
assert ee_trans is not None and ee_ori is not None
|
||||
sim_js = self.robot.get_joints_state()
|
||||
js_names = self.robot.dof_names
|
||||
result = self.plan(ee_trans, ee_ori, sim_js, js_names)
|
||||
succ = result.success.item()
|
||||
if succ:
|
||||
print("Success")
|
||||
return 1
|
||||
print("Plan did not converge to a solution.")
|
||||
return 0
|
||||
|
||||
def pre_forward(self, ee_trans: np.ndarray, ee_ori: np.ndarray, expected_js=None, ds_ratio=1):
|
||||
assert ee_trans is not None and ee_ori is not None
|
||||
ee_trans = self.tensor_args.to_device(ee_trans)
|
||||
ee_ori = self.tensor_args.to_device(ee_ori)
|
||||
sim_js = self.robot.get_joints_state()
|
||||
js_names = self.robot.dof_names
|
||||
if expected_js is not None:
|
||||
sim_js.positions[self.arm_indices] = expected_js
|
||||
result = self.plan(ee_trans, ee_ori, sim_js, js_names)
|
||||
if self.use_batch:
|
||||
if result.success.any():
|
||||
print("Success")
|
||||
cmd_plans = result.get_successful_paths()
|
||||
cmd_plan = random.choice(cmd_plans)
|
||||
cmd_plan = self.motion_gen.get_full_js(cmd_plan)
|
||||
cmd_plan = cmd_plan.get_ordered_joint_state(self.raw_js_names)
|
||||
N = cmd_plan.shape[0]
|
||||
dt = self.motion_gen.interpolation_dt
|
||||
self.ds_ratio = ds_ratio
|
||||
cmd_time = N * dt / self.plan_config.time_dilation_factor / self.ds_ratio
|
||||
return cmd_time, np.array(cmd_plan[-1].position.cpu())
|
||||
print("Plan did not converge to a solution.")
|
||||
self.num_plan_failed = 1000
|
||||
return 0, expected_js
|
||||
succ = result.success.item()
|
||||
if succ:
|
||||
print("Success")
|
||||
cmd_plan = result.get_interpolated_plan()
|
||||
N = cmd_plan.shape[0]
|
||||
dt = self.motion_gen.interpolation_dt
|
||||
self.ds_ratio = ds_ratio
|
||||
cmd_time = N * dt / self.plan_config.time_dilation_factor / self.ds_ratio
|
||||
return cmd_time, np.array(cmd_plan[-1].position.cpu())
|
||||
print("Plan did not converge to a solution.")
|
||||
self.num_plan_failed = 1000
|
||||
return 0, expected_js
|
||||
|
||||
def in_plane_rotation(self, target_rotate: np.ndarray):
|
||||
action = deepcopy(self._action)
|
||||
last_arm = len(self.arm_indices) - 1
|
||||
action["joint_positions"][last_arm] -= target_rotate
|
||||
action["arm_action"][last_arm] -= target_rotate
|
||||
return action
|
||||
|
||||
def mobile_move(self, target: np.ndarray, joint_indices: np.ndarray = None, initial_position: np.ndarray = None):
|
||||
return {
|
||||
"joint_positions": initial_position + target,
|
||||
"joint_indices": np.array(joint_indices),
|
||||
"lr_name": "whole",
|
||||
}
|
||||
|
||||
def dummy_forward(self, arm_action, gripper_state, *args, **kwargs):
|
||||
if gripper_state == 1.0:
|
||||
self.open_gripper()
|
||||
elif gripper_state == -1.0:
|
||||
self.close_gripper()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
gripper_action = self.get_gripper_action()
|
||||
return {
|
||||
"joint_positions": np.concatenate([arm_action, gripper_action]),
|
||||
"joint_indices": np.concatenate([self.arm_indices, self.gripper_indices]),
|
||||
"lr_name": self.lr_name,
|
||||
"arm_action": arm_action,
|
||||
"gripper_action": gripper_action,
|
||||
}
|
||||
Reference in New Issue
Block a user