Change config logic in:
- gym_manipulator - find_joint_limits - end_effector_utils
This commit is contained in:
committed by
AdilZouitine
parent
ee25fd8afe
commit
b7b6d8102f
@@ -40,13 +40,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if "images" not in key:
|
||||
continue
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
if not torch.is_tensor(img):
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, (
|
||||
f"expect channel last images, but instead got {img.shape=}"
|
||||
)
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
@@ -87,6 +87,8 @@ class RecordControlConfig(ControlConfig):
|
||||
play_sounds: bool = True
|
||||
# Resume recording on an existing dataset.
|
||||
resume: bool = False
|
||||
# Reset follower arms to an initial configuration.
|
||||
reset_follower_arms: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
|
||||
@@ -221,7 +221,7 @@ def record_episode(
|
||||
events=events,
|
||||
policy=policy,
|
||||
fps=fps,
|
||||
record_delta_actions=record_delta_actions,
|
||||
# record_delta_actions=record_delta_actions,
|
||||
teleoperate=policy is None,
|
||||
single_task=single_task,
|
||||
)
|
||||
@@ -267,8 +267,8 @@ def control_loop(
|
||||
|
||||
if teleoperate:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
if record_delta_actions:
|
||||
action["action"] = action["action"] - current_joint_positions
|
||||
# if record_delta_actions:
|
||||
# action["action"] = action["action"] - current_joint_positions
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
||||
@@ -443,7 +443,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
port="/dev/tty.usbmodem58760433331",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
@@ -460,7 +460,7 @@ class So100RobotConfig(ManipulatorRobotConfig):
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
port="/dev/tty.usbmodem58760431631",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
|
||||
@@ -475,12 +475,12 @@ class ManipulatorRobot:
|
||||
goal_pos = leader_pos[name]
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
if self.config.joint_position_relative_bounds is not None:
|
||||
goal_pos = torch.clamp(
|
||||
goal_pos,
|
||||
self.config.joint_position_relative_bounds["min"],
|
||||
self.config.joint_position_relative_bounds["max"],
|
||||
)
|
||||
# if self.config.joint_position_relative_bounds is not None:
|
||||
# goal_pos = torch.clamp(
|
||||
# goal_pos,
|
||||
# self.config.joint_position_relative_bounds["min"],
|
||||
# self.config.joint_position_relative_bounds["max"],
|
||||
# )
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
@@ -604,12 +604,12 @@ class ManipulatorRobot:
|
||||
from_idx = to_idx
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
if self.config.joint_position_relative_bounds is not None:
|
||||
goal_pos = torch.clamp(
|
||||
goal_pos,
|
||||
self.config.joint_position_relative_bounds["min"],
|
||||
self.config.joint_position_relative_bounds["max"],
|
||||
)
|
||||
# if self.config.joint_position_relative_bounds is not None:
|
||||
# goal_pos = torch.clamp(
|
||||
# goal_pos,
|
||||
# self.config.joint_position_relative_bounds["min"],
|
||||
# self.config.joint_position_relative_bounds["max"],
|
||||
# )
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
|
||||
Reference in New Issue
Block a user