Compare commits

..

3 Commits

Author SHA1 Message Date
Michel Aractingi
16b905ee67 adding sac implementation 2024-12-06 01:32:22 +01:00
Michel Aractingi
44d96a0811 nit 2024-11-27 15:19:20 +01:00
Michel Aractingi
4488e55e94 first commit 2024-11-27 15:01:06 +01:00
19 changed files with 1570 additions and 2528 deletions

View File

@@ -131,8 +131,7 @@ def encode_video_frames(
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
# vcodec: str = "libsvtav1",
vcodec: str = "libx264",
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
g: int | None = 2,
crf: int | None = 30,

View File

@@ -140,25 +140,25 @@ class ACTPolicy(
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
bsize = actions_hat.shape[0]
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
out_dict = {}
out_dict["l1_loss"] = l1_loss
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
mean_kld = (
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
else:
out_dict["loss"] = l1_loss
loss_dict["loss"] = l1_loss
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
return out_dict
return loss_dict
class ACTTemporalEnsembler:

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class HILSerlConfig:
    """Configuration for `HILSerlPolicy`.

    Currently an empty placeholder — hyper-parameters will be added as the
    HIL-SERL implementation lands.
    """

    pass

View File

@@ -0,0 +1,30 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "hilserl"],
):
    """Human-in-the-loop SERL policy.

    Currently an empty placeholder; it only registers Hub metadata
    (library name, repo URL, tags) via `PyTorchModelHubMixin`.
    """

    pass

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
@dataclass
class SACConfig:
    """Configuration for the Soft Actor-Critic (SAC) policy.

    Attributes:
        discount: Discount factor (gamma) applied to the bootstrapped value
            when building the TD target.
    """

    # Fix: the original `discount = 0.99` had no type annotation, so
    # `@dataclass` treated it as a plain class attribute rather than a field —
    # it could not be overridden via the constructor. Annotating it makes it a
    # real dataclass field while keeping the same default.
    discount: float = 0.99

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
class SACPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "RL", "SAC"],
):
    """Soft Actor-Critic policy (work in progress).

    The critic ensemble, target critic, actor network and temperature are still
    `...` placeholders; this class wires up input/output normalization and
    sketches the SAC losses (critic TD loss, actor loss, temperature loss).
    """

    def __init__(
        self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
    ):
        super().__init__()
        if config is None:
            config = SACConfig()
        self.config = config
        if config.input_normalization_modes is not None:
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, dataset_stats
            )
        else:
            self.normalize_inputs = nn.Identity()
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        # TODO: instantiate the actual networks; placeholders for now.
        self.critic_ensemble = ...
        self.critic_target = ...
        self.actor_network = ...
        self.temperature = ...

    def reset(self):
        """
        Clear observation and action queues. Should be called on `env.reset()`
        queues are populated during rollout of the policy, they contain the n latest observations and actions
        """
        self._queues = {
            "observation.state": deque(maxlen=1),
            "action": deque(maxlen=1),
        }
        # NOTE(review): `_use_image` / `_use_env_state` are never set in
        # `__init__` — TODO define them from the config's input shapes.
        if self._use_image:
            self._queues["observation.image"] = deque(maxlen=1)
        if self._use_env_state:
            self._queues["observation.environment_state"] = deque(maxlen=1)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Sample an action from the actor for environment interaction."""
        # Fix: the original computed the action but never returned it (and left
        # a stray `###` at the end of the line).
        actions, _ = self.actor_network(batch["observations"])
        return actions

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
        """Run the batch through the model and compute the loss.

        Returns a dictionary with loss as a tensor, and other information as native floats.
        """
        # Fix: the original left these assignments empty (invalid syntax). The
        # key names below are assumptions — TODO confirm against the replay
        # buffer's batch schema.
        observation_batch = batch["observations"]
        next_observation_batch = batch["next_observations"]
        action_batch = batch["actions"]
        reward_batch = batch["rewards"]
        done_batch = batch["dones"]
        # perform image augmentation
        # reward bias
        # from HIL-SERL code base
        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

        # calculate critics loss
        # 1- compute actions from policy (actor returns (actions, log_probs))
        next_actions, _ = self.actor_network(next_observation_batch)
        # 2- compute q targets. Fix: the original called the undefined
        # `self.target_qs`; the target critic set up in `__init__` is
        # `self.critic_target`.
        q_targets = self.critic_target(next_observation_batch, next_actions)
        # critics subsample size
        # Fix: `Tensor.min(dim=0)` returns a (values, indices) namedtuple, not a
        # tensor — take the values (ensemble-min over the critic dimension).
        min_q = q_targets.min(dim=0).values
        # backup entropy
        # Fix: mask the bootstrap term on terminal transitions (the original
        # noted "dones masks" but never applied them), and read the discount
        # from the config — `self.discount` was never set.
        td_target = reward_batch + (1 - done_batch) * self.config.discount * min_q
        # 3- compute predicted qs
        q_preds = self.critic_ensemble(observation_batch, action_batch)
        # 4- Calculate loss: each ensemble member regresses to the same target.
        critics_loss = F.mse_loss(
            q_preds, einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
        )

        # calculate actors loss
        # 1- temperature
        temperature = self.temperature()
        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
        actions, log_probs = self.actor_network(observation_batch)
        # 3- get q-value predictions (no grad through the critics for the actor)
        with torch.no_grad():
            q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
        actor_loss = -(q_preds - temperature * log_probs).mean()

        # calculate temperature loss
        # 1- calculate entropy
        entropy = -log_probs.mean()
        # NOTE(review): assumes `self.target_entropy` is defined elsewhere
        # (e.g. derived from the action dimension) — TODO wire it up.
        temperature_loss = temperature * (entropy - self.target_entropy).mean()

        loss = critics_loss + actor_loss + temperature_loss
        return {
            "Q_value_loss": critics_loss.item(),
            "pi_loss": actor_loss.item(),
            "temperature_loss": temperature_loss.item(),
            "temperature": temperature.item(),
            "entropy": entropy.item(),
            "loss": loss,
        }

    def update(self):
        """Polyak-average the critic ensemble into the target critic."""
        # NOTE(review): `lerp_` is a tensor op, not an `nn.Module` op — once the
        # critics are real modules this needs the commented per-parameter loop.
        self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
        # for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
        #    target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)

View File

@@ -1,90 +0,0 @@
"""
Wrapper for Reachy2 camera from sdk
"""
from dataclasses import dataclass, replace
import cv2
import numpy as np
from reachy2_sdk.media.camera import CameraView
from reachy2_sdk.media.camera_manager import CameraManager
@dataclass
class ReachyCameraConfig:
fps: int | None = None
width: int | None = None
height: int | None = None
color_mode: str = "rgb"
rotation: int | None = None
mock: bool = False
def __post_init__(self):
if self.color_mode not in ["rgb", "bgr"]:
raise ValueError(
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
)
self.channels = 3
class ReachyCamera:
    """Wrapper for a Reachy 2 camera exposed through the reachy2_sdk
    `CameraManager` (teleop left/right, or the depth camera's depth/rgb feeds).
    """

    def __init__(
        self,
        host: str,
        port: int,
        name: str,
        image_type: str,
        config: ReachyCameraConfig | None = None,
        **kwargs,
    ):
        if config is None:
            config = ReachyCameraConfig()
        # Overwrite config arguments using kwargs
        config = replace(config, **kwargs)
        self.host = host
        self.port = port
        self.width = config.width
        self.height = config.height
        self.channels = config.channels
        self.fps = config.fps
        self.image_type = image_type
        self.name = name
        self.config = config
        self.cam_manager = None
        self.is_connected = False
        self.logs = {}

    def connect(self):
        """Create the SDK camera manager (idempotent once connected)."""
        if not self.is_connected:
            self.cam_manager = CameraManager(host=self.host, port=self.port)
            self.cam_manager.initialize_cameras()  # FIXME: maybe we should not re-initialize
            self.is_connected = True

    def read(self) -> np.ndarray:
        """Grab one frame.

        Returns a `(image, timestamp)` tuple from the SDK (converted to RGB when
        `color_mode == "rgb"`), or `None` if no frame is available.
        NOTE(review): the `np.ndarray` annotation does not match the actual
        tuple/None return — TODO confirm the intended contract before changing it.
        """
        if not self.is_connected:
            self.connect()
        frame = None
        if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
            if self.image_type == "left":
                frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
            elif self.image_type == "right":
                frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
        elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
            if self.image_type == "depth":
                frame = self.cam_manager.depth.get_depth_frame()
            elif self.image_type == "rgb":
                frame = self.cam_manager.depth.get_frame()
        if frame is None:
            return None
        # Fix: dropped the original redundant `frame is not None` re-check —
        # the early return above already guarantees it.
        if self.config.color_mode == "rgb":
            # SDK frames arrive as BGR; convert for rgb consumers.
            img, timestamp = frame
            frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)
        return frame

View File

@@ -46,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.lower().startswith(("stretch", "reachy")):
if not robot.robot_type.startswith("stretch"):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:

View File

@@ -1,317 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The Pollen Robotics team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from copy import copy
from dataclasses import dataclass, field, replace
import numpy as np
import torch
from reachy2_sdk import ReachySDK
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
# Flat ordering of all Reachy 2 action/state entries. The index of each name
# here fixes the layout of the action and state vectors used by `ReachyRobot`
# (neck x3, right arm x7 + gripper, left arm x7 + gripper, mobile base x3 = 22).
REACHY_MOTORS = [
    "neck_yaw.pos",
    "neck_pitch.pos",
    "neck_roll.pos",
    "r_shoulder_pitch.pos",
    "r_shoulder_roll.pos",
    "r_elbow_yaw.pos",
    "r_elbow_pitch.pos",
    "r_wrist_roll.pos",
    "r_wrist_pitch.pos",
    "r_wrist_yaw.pos",
    "r_gripper.pos",
    "l_shoulder_pitch.pos",
    "l_shoulder_roll.pos",
    "l_elbow_yaw.pos",
    "l_elbow_pitch.pos",
    "l_wrist_roll.pos",
    "l_wrist_pitch.pos",
    "l_wrist_yaw.pos",
    "l_gripper.pos",
    "mobile_base.vx",
    "mobile_base.vy",
    "mobile_base.vtheta",
]
@dataclass
class ReachyRobotConfig:
    """Connection and peripheral configuration for `ReachyRobot`."""

    robot_type: str | None = "reachy2"
    # Mapping from camera name to its wrapper. `dict` is the idiomatic
    # default_factory (the original `lambda: {}` built the same empty dict).
    cameras: dict[str, ReachyCamera] = field(default_factory=dict)
    ip_address: str | None = "172.17.135.207"
    # ip_address: str | None = "192.168.0.197"
    # ip_address: str | None = "localhost"
class ReachyRobot:
    """Wrapper of ReachySDK.

    Exposes the lerobot robot interface (connect / teleop_step /
    capture_observation / send_action / disconnect) over a Reachy 2. Action and
    state vectors follow the 22-entry layout of `REACHY_MOTORS`.
    """

    def __init__(self, config: ReachyRobotConfig | None = None, **kwargs):
        if config is None:
            config = ReachyRobotConfig()
        # Overwrite config arguments using kwargs
        self.config = replace(config, **kwargs)
        self.robot_type = self.config.robot_type
        self.cameras = self.config.cameras
        # NOTE(review): always True, even with zero cameras — confirm intended.
        self.has_camera = True
        self.num_cameras = len(self.cameras)
        self.is_connected = False
        self.teleop = None
        self.logs = {}
        self.reachy = None
        self.mobile_base_available = False
        self.state_keys = None
        self.action_keys = None

    @property
    def camera_features(self) -> dict:
        """Dataset feature spec for each configured camera image."""
        cam_ft = {}
        for cam_key, cam in self.cameras.items():
            key = f"observation.images.{cam_key}"
            cam_ft[key] = {
                "shape": (cam.height, cam.width, cam.channels),
                "names": ["height", "width", "channels"],
                "info": None,
            }
        return cam_ft

    @property
    def motor_features(self) -> dict:
        """Dataset feature spec for the action and state vectors."""
        motors = REACHY_MOTORS
        # if self.mobile_base_available:
        #     motors += REACHY_MOBILE_BASE
        return {
            "action": {
                "dtype": "float32",
                "shape": (len(motors),),
                "names": motors,
            },
            "observation.state": {
                "dtype": "float32",
                "shape": (len(motors),),
                "names": motors,
            },
        }

    @property
    def features(self):
        """Combined motor + camera feature spec."""
        return {**self.motor_features, **self.camera_features}

    def connect(self) -> None:
        """Connect to the robot and every configured camera; raises ConnectionError on failure."""
        self.reachy = ReachySDK(host=self.config.ip_address)
        print("Connecting to Reachy")
        self.reachy.connect()
        self.is_connected = self.reachy.is_connected
        if not self.is_connected:
            print(
                f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
            )
            raise ConnectionError()
        # self.reachy.turn_on()
        print(self.cameras)
        if self.cameras is not None:
            for name in self.cameras:
                print(f"Connecting camera: {name}")
                self.cameras[name].connect()
                self.is_connected = self.is_connected and self.cameras[name].is_connected
        if not self.is_connected:
            print("Could not connect to the cameras, check that all cameras are plugged-in.")
            raise ConnectionError()
        self.mobile_base_available = self.reachy.mobile_base is not None

    def run_calibration(self):
        # Reachy 2 needs no runtime calibration through this wrapper.
        pass

    def teleop_step(
        self, record_data=False
    ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Record one (observation, action) pair from the current goal positions.

        Returns None when `record_data` is False.
        """
        if not record_data:
            return
        action = {}
        # Fix: record the neck joints in yaw/pitch/roll order so the flattened
        # vector matches `motor_features` ("neck_yaw.pos" is index 0) and
        # `send_action`, which reads action[0] as neck yaw. The original
        # inserted roll/pitch/yaw first, mis-aligning recorded data.
        action["neck_yaw.pos"] = self.reachy.head.neck.yaw.goal_position
        action["neck_pitch.pos"] = self.reachy.head.neck.pitch.goal_position
        action["neck_roll.pos"] = self.reachy.head.neck.roll.goal_position
        action["r_shoulder_pitch.pos"] = self.reachy.r_arm.shoulder.pitch.goal_position
        action["r_shoulder_roll.pos"] = self.reachy.r_arm.shoulder.roll.goal_position
        action["r_elbow_yaw.pos"] = self.reachy.r_arm.elbow.yaw.goal_position
        action["r_elbow_pitch.pos"] = self.reachy.r_arm.elbow.pitch.goal_position
        action["r_wrist_roll.pos"] = self.reachy.r_arm.wrist.roll.goal_position
        action["r_wrist_pitch.pos"] = self.reachy.r_arm.wrist.pitch.goal_position
        action["r_wrist_yaw.pos"] = self.reachy.r_arm.wrist.yaw.goal_position
        action["r_gripper.pos"] = self.reachy.r_arm.gripper.opening
        action["l_shoulder_pitch.pos"] = self.reachy.l_arm.shoulder.pitch.goal_position
        action["l_shoulder_roll.pos"] = self.reachy.l_arm.shoulder.roll.goal_position
        action["l_elbow_yaw.pos"] = self.reachy.l_arm.elbow.yaw.goal_position
        action["l_elbow_pitch.pos"] = self.reachy.l_arm.elbow.pitch.goal_position
        action["l_wrist_roll.pos"] = self.reachy.l_arm.wrist.roll.goal_position
        action["l_wrist_pitch.pos"] = self.reachy.l_arm.wrist.pitch.goal_position
        action["l_wrist_yaw.pos"] = self.reachy.l_arm.wrist.yaw.goal_position
        action["l_gripper.pos"] = self.reachy.l_arm.gripper.opening
        # NOTE(review): these keys differ from the feature names
        # ("mobile_base.vx" etc.) but only the dict's value order matters below.
        if self.mobile_base_available:
            last_cmd_vel = self.reachy.mobile_base.last_cmd_vel
            action["mobile_base_x.vel"] = last_cmd_vel["x"]
            action["mobile_base_y.vel"] = last_cmd_vel["y"]
            action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
        else:
            action["mobile_base_x.vel"] = 0
            action["mobile_base_y.vel"] = 0
            action["mobile_base_theta.vel"] = 0
        dtype = self.motor_features["action"]["dtype"]
        action = np.array(list(action.values()), dtype=dtype)
        # action = torch.as_tensor(list(action.values()))
        obs_dict = self.capture_observation()
        action_dict = {}
        action_dict["action"] = action
        return obs_dict, action_dict

    def get_state(self) -> dict:
        """Read all present positions (and mobile-base odometry velocities).

        Key order matches REACHY_MOTORS; returns {} when disconnected.
        """
        if self.is_connected:
            if self.mobile_base_available:
                odometry = self.reachy.mobile_base.odometry
            else:
                odometry = {"x": 0, "y": 0, "theta": 0, "vx": 0, "vy": 0, "vtheta": 0}
            return {
                "neck_yaw.pos": self.reachy.head.neck.yaw.present_position,
                "neck_pitch.pos": self.reachy.head.neck.pitch.present_position,
                "neck_roll.pos": self.reachy.head.neck.roll.present_position,
                "r_shoulder_pitch.pos": self.reachy.r_arm.shoulder.pitch.present_position,
                "r_shoulder_roll.pos": self.reachy.r_arm.shoulder.roll.present_position,
                "r_elbow_yaw.pos": self.reachy.r_arm.elbow.yaw.present_position,
                "r_elbow_pitch.pos": self.reachy.r_arm.elbow.pitch.present_position,
                "r_wrist_roll.pos": self.reachy.r_arm.wrist.roll.present_position,
                "r_wrist_pitch.pos": self.reachy.r_arm.wrist.pitch.present_position,
                "r_wrist_yaw.pos": self.reachy.r_arm.wrist.yaw.present_position,
                "r_gripper.pos": self.reachy.r_arm.gripper.present_position,
                "l_shoulder_pitch.pos": self.reachy.l_arm.shoulder.pitch.present_position,
                "l_shoulder_roll.pos": self.reachy.l_arm.shoulder.roll.present_position,
                "l_elbow_yaw.pos": self.reachy.l_arm.elbow.yaw.present_position,
                "l_elbow_pitch.pos": self.reachy.l_arm.elbow.pitch.present_position,
                "l_wrist_roll.pos": self.reachy.l_arm.wrist.roll.present_position,
                "l_wrist_pitch.pos": self.reachy.l_arm.wrist.pitch.present_position,
                "l_wrist_yaw.pos": self.reachy.l_arm.wrist.yaw.present_position,
                "l_gripper.pos": self.reachy.l_arm.gripper.present_position,
                "mobile_base.vx": odometry["vx"],
                "mobile_base.vy": odometry["vy"],
                "mobile_base.vtheta": odometry["vtheta"],
            }
        else:
            return {}

    def capture_observation(self) -> dict:
        """Capture state vector + one frame per camera; returns {} when disconnected."""
        if self.is_connected:
            before_read_t = time.perf_counter()
            state = self.get_state()
            self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
            if self.state_keys is None:
                self.state_keys = list(state)
            dtype = self.motor_features["observation.state"]["dtype"]
            state = np.array(list(state.values()), dtype=dtype)
            # state = torch.as_tensor(list(state.values()))
            # Capture images from cameras
            images = {}
            for name in self.cameras:
                # before_camread_t = time.perf_counter()
                images[name] = self.cameras[name].read()  # Reachy cameras read() is not blocking?
                if images[name] is not None:
                    # Fix: unpack the (image, timestamp) tuple BEFORE replacing
                    # the entry with a tensor — the original indexed
                    # `images[name][1]` after the tensor conversion, logging an
                    # image row instead of the timestamp.
                    img, timestamp = images[name]
                    self.logs[f"read_camera_{name}_dt_s"] = timestamp  # full timestamp, TODO dt
                    images[name] = torch.from_numpy(copy(img))  # seems like I need to copy?
            # Populate output dictionnaries
            obs_dict = {}
            obs_dict["observation.state"] = state
            for name in self.cameras:
                obs_dict[f"observation.images.{name}"] = images[name]
            return obs_dict
        else:
            return {}

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Send one 22-d action vector to the robot (layout per REACHY_MOTORS)."""
        if not self.is_connected:
            raise ConnectionError()
        self.reachy.head.neck.yaw.goal_position = float(action[0])
        self.reachy.head.neck.pitch.goal_position = float(action[1])
        self.reachy.head.neck.roll.goal_position = float(action[2])
        self.reachy.r_arm.shoulder.pitch.goal_position = float(action[3])
        self.reachy.r_arm.shoulder.roll.goal_position = float(action[4])
        self.reachy.r_arm.elbow.yaw.goal_position = float(action[5])
        self.reachy.r_arm.elbow.pitch.goal_position = float(action[6])
        self.reachy.r_arm.wrist.roll.goal_position = float(action[7])
        # Fix: action[8] is "r_wrist_pitch.pos" in REACHY_MOTORS; the original
        # wrote it to wrist.roll a second time and never commanded wrist.pitch.
        self.reachy.r_arm.wrist.pitch.goal_position = float(action[8])
        self.reachy.r_arm.wrist.yaw.goal_position = float(action[9])
        self.reachy.r_arm.gripper.set_opening(float(action[10]))
        self.reachy.l_arm.shoulder.pitch.goal_position = float(action[11])
        self.reachy.l_arm.shoulder.roll.goal_position = float(action[12])
        self.reachy.l_arm.elbow.yaw.goal_position = float(action[13])
        self.reachy.l_arm.elbow.pitch.goal_position = float(action[14])
        self.reachy.l_arm.wrist.roll.goal_position = float(action[15])
        # Fix: same duplicated-roll bug on the left arm (action[16] is l_wrist_pitch.pos).
        self.reachy.l_arm.wrist.pitch.goal_position = float(action[16])
        self.reachy.l_arm.wrist.yaw.goal_position = float(action[17])
        self.reachy.l_arm.gripper.set_opening(float(action[18]))
        s = time.time()
        self.reachy.send_goal_positions(check_positions=False)
        print("send_goal_positions", time.time() - s)
        if self.mobile_base_available:
            self.reachy.mobile_base.set_goal_speed(action[19], action[20], action[21])
            self.reachy.mobile_base.send_speed_command()
        # TODO: what shape is the action tensor?
        # 7 dofs per arm (x2)
        # 1 dof per gripper (x2)
        # 3 dofs for the neck
        # 3 dofs for the mobile base (x, y, theta)
        # 7+7+1+1+3+3 = 22
        return action

    def print_logs(self) -> None:
        pass

    def disconnect(self) -> None:
        """Mark as disconnected and close the SDK connection (turn-off is commented out)."""
        print("Disconnecting")
        self.is_connected = False
        print("Turn off")
        # self.reachy.turn_off_smoothly()
        # self.reachy.turn_off()
        print("\t turn off done")
        self.reachy.disconnect()

View File

@@ -1,342 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The Pollen Robotics team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from copy import copy
from dataclasses import dataclass, field, replace
import numpy as np
import torch
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
from reachy2_sdk import ReachySDK
# Flat ordering of all Reachy 2 action/state entries. The index of each name
# here fixes the layout of the action and state vectors used by
# `ReachyManipulatorRobot` (neck x3, right arm x7 + gripper, left arm x7 +
# gripper, mobile base x3 = 22).
REACHY_MOTORS = [
    "neck_yaw.pos",
    "neck_pitch.pos",
    "neck_roll.pos",
    "r_shoulder_pitch.pos",
    "r_shoulder_roll.pos",
    "r_elbow_yaw.pos",
    "r_elbow_pitch.pos",
    "r_wrist_roll.pos",
    "r_wrist_pitch.pos",
    "r_wrist_yaw.pos",
    "r_gripper.pos",
    "l_shoulder_pitch.pos",
    "l_shoulder_roll.pos",
    "l_elbow_yaw.pos",
    "l_elbow_pitch.pos",
    "l_wrist_roll.pos",
    "l_wrist_pitch.pos",
    "l_wrist_yaw.pos",
    "l_gripper.pos",
    "mobile_base.vx",
    "mobile_base.vy",
    "mobile_base.vtheta",
]
@dataclass
class ReachyManipulatorRobotConfig:
    """Connection and peripheral configuration for `ReachyManipulatorRobot`.

    NOTE(review): `ReachyManipulatorRobot.__init__` also reads a `leader_arm`
    attribute (port/motors/calibration_dir) that is not declared here — TODO
    add that field.
    """

    robot_type: str | None = "reachy2"
    # Mapping from camera name to its wrapper. `dict` is the idiomatic
    # default_factory (the original `lambda: {}` built the same empty dict).
    cameras: dict[str, ReachyCamera] = field(default_factory=dict)
    ip_address: str | None = "172.17.135.207"
    # ip_address: str | None = "192.168.0.197"
    # ip_address: str | None = "localhost"
class ReachyManipulatorRobot:
"""Wrapper of ReachySDK"""
def __init__(self, config: ReachyRobotManipulatorConfig | None = None, **kwargs):
if config is None:
config = ReachyRobotManipulatorConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.robot_type = self.config.robot_type
self.cameras = self.config.cameras
self.has_camera = True
self.num_cameras = len(self.cameras)
self.is_connected = False
self.teleop = None
self.logs = {}
self.reachy = None
self.mobile_base_available = False
self.state_keys = None
self.action_keys = None
self.leader_arm = FeetechMotorsBus(config.leader_arm.port, config.leader_arm.motors)
self.leader_calib_dir=config.leader_arm.calibration_dir
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
motors = REACHY_MOTORS
# if self.mobile_base_available:
# motors += REACHY_MOBILE_BASE
return {
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": motors,
},
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": motors,
},
}
@property
def features(self):
return {**self.motor_features, **self.camera_features}
def connect(self) -> None:
self.reachy = ReachySDK(host=self.config.ip_address)
print("Connecting to Reachy")
self.reachy.connect()
self.is_connected = self.reachy.is_connected
if not self.is_connected:
print(
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
)
raise ConnectionError()
# self.reachy.turn_on()
print(self.cameras)
if self.cameras is not None:
for name in self.cameras:
print(f"Connecting camera: {name}")
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
if not self.is_connected:
print("Could not connect to the cameras, check that all cameras are plugged-in.")
raise ConnectionError()
self.mobile_base_available = self.reachy.mobile_base is not None
print("Connecting to leader arm")
self.leader_arm.connect()
with open(self.leader_arm.calibration_dir) as f:
self.leader_arm.calibration = json.load(f)
self.leader_arm.set_calibration(self.leader_arm.calibration)
self.leader_arm.apply_calibration()
def run_calibration(self):
pass
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not record_data:
return
#get leader arm
leader_pos = {}
for name in self.leader_arm:
before_lread_t = time.perf_counter()
leader_pos[name] = self.leader_arm[name].read("Present_Position")
leader_pos[name] = torch.from_numpy(leader_pos[name])
self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t
#TODO leader arm FK
#TODO senf task space poses
action = {}
action["neck_roll.pos"] = self.reachy.head.neck.roll.goal_position
action["neck_pitch.pos"] = self.reachy.head.neck.pitch.goal_position
action["neck_yaw.pos"] = self.reachy.head.neck.yaw.goal_position
action["r_shoulder_pitch.pos"] = self.reachy.r_arm.shoulder.pitch.goal_position
action["r_shoulder_roll.pos"] = self.reachy.r_arm.shoulder.roll.goal_position
action["r_elbow_yaw.pos"] = self.reachy.r_arm.elbow.yaw.goal_position
action["r_elbow_pitch.pos"] = self.reachy.r_arm.elbow.pitch.goal_position
action["r_wrist_roll.pos"] = self.reachy.r_arm.wrist.roll.goal_position
action["r_wrist_pitch.pos"] = self.reachy.r_arm.wrist.pitch.goal_position
action["r_wrist_yaw.pos"] = self.reachy.r_arm.wrist.yaw.goal_position
action["r_gripper.pos"] = self.reachy.r_arm.gripper.opening
action["l_shoulder_pitch.pos"] = self.reachy.l_arm.shoulder.pitch.goal_position
action["l_shoulder_roll.pos"] = self.reachy.l_arm.shoulder.roll.goal_position
action["l_elbow_yaw.pos"] = self.reachy.l_arm.elbow.yaw.goal_position
action["l_elbow_pitch.pos"] = self.reachy.l_arm.elbow.pitch.goal_position
action["l_wrist_roll.pos"] = self.reachy.l_arm.wrist.roll.goal_position
action["l_wrist_pitch.pos"] = self.reachy.l_arm.wrist.pitch.goal_position
action["l_wrist_yaw.pos"] = self.reachy.l_arm.wrist.yaw.goal_position
action["l_gripper.pos"] = self.reachy.l_arm.gripper.opening
if self.mobile_base_available:
last_cmd_vel = self.reachy.mobile_base.last_cmd_vel
action["mobile_base_x.vel"] = last_cmd_vel["x"]
action["mobile_base_y.vel"] = last_cmd_vel["y"]
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
else:
action["mobile_base_x.vel"] = 0
action["mobile_base_y.vel"] = 0
action["mobile_base_theta.vel"] = 0
dtype = self.motor_features["action"]["dtype"]
action = np.array(list(action.values()), dtype=dtype)
# action = torch.as_tensor(list(action.values()))
obs_dict = self.capture_observation()
action_dict = {}
action_dict["action"] = action
return obs_dict, action_dict
def get_state(self) -> dict:
# neck roll, pitch, yaw
# r_shoulder_pitch, r_shoulder_roll, r_elbow_yaw, r_elbow_pitch, r_wrist_roll, r_wrist_pitch, r_wrist_yaw, r_gripper
# l_shoulder_pitch, l_shoulder_roll, l_elbow_yaw, l_elbow_pitch, l_wrist_roll, l_wrist_pitch, l_wrist_yaw, l_gripper
# mobile base x, y, theta
if self.is_connected:
if self.mobile_base_available:
odometry = self.reachy.mobile_base.odometry
else:
odometry = {"x": 0, "y": 0, "theta": 0, "vx": 0, "vy": 0, "vtheta": 0}
return {
"neck_yaw.pos": self.reachy.head.neck.yaw.present_position,
"neck_pitch.pos": self.reachy.head.neck.pitch.present_position,
"neck_roll.pos": self.reachy.head.neck.roll.present_position,
"r_shoulder_pitch.pos": self.reachy.r_arm.shoulder.pitch.present_position,
"r_shoulder_roll.pos": self.reachy.r_arm.shoulder.roll.present_position,
"r_elbow_yaw.pos": self.reachy.r_arm.elbow.yaw.present_position,
"r_elbow_pitch.pos": self.reachy.r_arm.elbow.pitch.present_position,
"r_wrist_roll.pos": self.reachy.r_arm.wrist.roll.present_position,
"r_wrist_pitch.pos": self.reachy.r_arm.wrist.pitch.present_position,
"r_wrist_yaw.pos": self.reachy.r_arm.wrist.yaw.present_position,
"r_gripper.pos": self.reachy.r_arm.gripper.present_position,
"l_shoulder_pitch.pos": self.reachy.l_arm.shoulder.pitch.present_position,
"l_shoulder_roll.pos": self.reachy.l_arm.shoulder.roll.present_position,
"l_elbow_yaw.pos": self.reachy.l_arm.elbow.yaw.present_position,
"l_elbow_pitch.pos": self.reachy.l_arm.elbow.pitch.present_position,
"l_wrist_roll.pos": self.reachy.l_arm.wrist.roll.present_position,
"l_wrist_pitch.pos": self.reachy.l_arm.wrist.pitch.present_position,
"l_wrist_yaw.pos": self.reachy.l_arm.wrist.yaw.present_position,
"l_gripper.pos": self.reachy.l_arm.gripper.present_position,
"mobile_base.vx": odometry["vx"],
"mobile_base.vy": odometry["vy"],
"mobile_base.vtheta": odometry["vtheta"],
}
else:
return {}
def capture_observation(self) -> dict:
if self.is_connected:
before_read_t = time.perf_counter()
state = self.get_state()
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if self.state_keys is None:
self.state_keys = list(state)
dtype = self.motor_features["observation.state"]["dtype"]
state = np.array(list(state.values()), dtype=dtype)
# state = torch.as_tensor(list(state.values()))
# Capture images from cameras
images = {}
for name in self.cameras:
# before_camread_t = time.perf_counter()
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
# print(f'name: {name} img: {images[name]}')
if images[name] is not None:
# images[name] = copy(images[name][0]) # seems like I need to copy?
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
# Populate output dictionnaries
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict
else:
return {}
def send_action(self, action: torch.Tensor) -> torch.Tensor:
    """Send a 22-dof action tensor to the robot.

    Action layout (matching the ordering used by ``get_state``):
      0-2    head neck yaw / pitch / roll
      3-9    right arm: shoulder pitch/roll, elbow yaw/pitch, wrist roll/pitch/yaw
      10     right gripper opening
      11-17  left arm: same joint order as the right arm
      18     left gripper opening
      19-21  mobile base vx / vy / vtheta (only sent if the base is available)
    7 + 1 + 7 + 1 + 3 + 3 = 22 values total.

    Args:
        action: 1-D tensor of at least 19 (22 with mobile base) goal values.

    Returns:
        The action tensor unchanged, so callers can record what was sent.

    Raises:
        ConnectionError: if the robot is not connected.
    """
    if not self.is_connected:
        raise ConnectionError()

    self.reachy.head.neck.yaw.goal_position = float(action[0])
    self.reachy.head.neck.pitch.goal_position = float(action[1])
    self.reachy.head.neck.roll.goal_position = float(action[2])

    self.reachy.r_arm.shoulder.pitch.goal_position = float(action[3])
    self.reachy.r_arm.shoulder.roll.goal_position = float(action[4])
    self.reachy.r_arm.elbow.yaw.goal_position = float(action[5])
    self.reachy.r_arm.elbow.pitch.goal_position = float(action[6])
    self.reachy.r_arm.wrist.roll.goal_position = float(action[7])
    # BUG FIX: wrist.roll was assigned twice; action[8] is the wrist PITCH
    # target (see get_state's roll/pitch/yaw ordering).
    self.reachy.r_arm.wrist.pitch.goal_position = float(action[8])
    self.reachy.r_arm.wrist.yaw.goal_position = float(action[9])
    self.reachy.r_arm.gripper.set_opening(float(action[10]))

    self.reachy.l_arm.shoulder.pitch.goal_position = float(action[11])
    self.reachy.l_arm.shoulder.roll.goal_position = float(action[12])
    self.reachy.l_arm.elbow.yaw.goal_position = float(action[13])
    self.reachy.l_arm.elbow.pitch.goal_position = float(action[14])
    self.reachy.l_arm.wrist.roll.goal_position = float(action[15])
    # BUG FIX: same duplicated-roll bug on the left arm; action[16] is pitch.
    self.reachy.l_arm.wrist.pitch.goal_position = float(action[16])
    self.reachy.l_arm.wrist.yaw.goal_position = float(action[17])
    self.reachy.l_arm.gripper.set_opening(float(action[18]))

    # Record write latency in self.logs (consistent with "read_pos_dt_s" in
    # capture_observation) instead of printing to stdout.
    before_write_t = time.perf_counter()
    self.reachy.send_goal_positions(check_positions=False)
    self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t

    if self.mobile_base_available:
        self.reachy.mobile_base.set_goal_speed(
            float(action[19]), float(action[20]), float(action[21])
        )
        self.reachy.mobile_base.send_speed_command()

    return action
def print_logs(self) -> None:
    """Intentionally a no-op: log printing is not implemented for this robot."""
def disconnect(self) -> None:
    """Mark the robot as disconnected and close the SDK connection.

    NOTE(review): the turn-off calls below appear deliberately disabled,
    so motors keep their last state on disconnect — confirm this is intended.
    """
    print("Disconnecting")
    self.is_connected = False
    print("Turn off")
    # self.reachy.turn_off_smoothly()
    # self.reachy.turn_off()
    print("\t turn off done")
    self.reachy.disconnect()

View File

@@ -1,50 +0,0 @@
# [Reachy2 from Pollen Robotics](https://www.pollen-robotics.com)
# Requires installing extras packages
# With pip: `pip install -e ".[reachy2]"`
# With poetry: `poetry install --sync --extras "reachy2"`
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
robot_type: reachy2
cameras:
head_left:
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
name: teleop
host: 172.17.134.85
# host: 192.168.0.197
# host: localhost
port: 50065
fps: 30
width: 960
height: 720
image_type: left
# head_right:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: teleop
# host: 172.17.135.207
# port: 50065
# image_type: right
# fps: 30
# width: 960
# height: 720
# torso_rgb:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# # host: localhost
# port: 50065
# image_type: rgb
# fps: 30
# width: 1280
# height: 720
# torso_depth:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# port: 50065
# image_type: depth
# fps: 30
# width: 1280
# height: 720

View File

@@ -1,65 +0,0 @@
# [Reachy2 from Pollen Robotics](https://www.pollen-robotics.com)
# Requires installing extras packages
# With pip: `pip install -e ".[reachy2]"`
# With poetry: `poetry install --sync --extras "reachy2"`
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobotManipulatorconfig
robot_type: reachy2
leader_arm:
main:
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
port: /dev/tty.usbmodem585A0077581
calibration_dir: .cache/calibration/so100
motors:
# name: (index, model)
shoulder_pan: [1, "sts3215"]
shoulder_lift: [2, "sts3215"]
elbow_flex: [3, "sts3215"]
wrist_flex: [4, "sts3215"]
wrist_roll: [5, "sts3215"]
gripper: [6, "sts3215"]
cameras:
head_left:
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
name: teleop
host: 172.17.134.85
# host: 192.168.0.197
# host: localhost
port: 50065
fps: 30
width: 960
height: 720
image_type: left
# head_right:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: teleop
# host: 172.17.135.207
# port: 50065
# image_type: right
# fps: 30
# width: 960
# height: 720
# torso_rgb:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# # host: localhost
# port: 50065
# image_type: rgb
# fps: 30
# width: 1280
# height: 720
# torso_depth:
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
# name: depth
# host: 172.17.135.207
# port: 50065
# image_type: depth
# fps: 30
# width: 1280
# height: 720

View File

@@ -191,7 +191,7 @@ def teleoperate(
@safe_disconnect
def record(
robot: Robot,
root: Path,
root: str,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
@@ -204,7 +204,6 @@ def record(
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
@@ -332,7 +331,7 @@ def record(
dataset.consolidate(run_compute_stats)
if push_to_hub:
dataset.push_to_hub(tags=tags)
dataset.push_to_hub()
log_say("Exiting", play_sounds)
return dataset
@@ -428,7 +427,7 @@ if __name__ == "__main__":
parser_record.add_argument(
"--root",
type=Path,
default=None,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_record.add_argument(
@@ -437,12 +436,6 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--resume",
type=int,
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
@@ -501,6 +494,12 @@ if __name__ == "__main__":
"Not enough threads might cause low camera fps."
),
)
parser_record.add_argument(
"--force-override",
type=int,
default=0,
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
)
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
@@ -524,7 +523,7 @@ if __name__ == "__main__":
parser_replay.add_argument(
"--root",
type=Path,
default=None,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(

View File

@@ -1,21 +0,0 @@
import time
# from safetensors.torch import load_file, save_file
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, init_logging
if __name__ == "__main__":
init_logging()
control_mode = "test"
robot_path = "lerobot/configs/robot/reachy2.yaml"
robot_overrides = None
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
print(robot.is_connected)
# print(robot.get_state())
print(robot.capture_observation())
time.sleep(5)
robot.disconnect()

View File

@@ -268,11 +268,10 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
# root = kwargs.pop("root")
root = kwargs.pop("root")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id)
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
visualize_dataset(dataset, **vars(args))

View File

@@ -55,30 +55,13 @@ python lerobot/scripts/visualize_dataset_html.py \
import argparse
import logging
import shutil
import warnings
from pathlib import Path
import torch
import tqdm
from flask import Flask, redirect, render_template, url_for
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config, init_logging
from lerobot.scripts.eval import get_pretrained_policy_path
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset, episode_index):
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
self.frame_ids = range(from_idx, to_idx)
def __iter__(self):
return iter(self.frame_ids)
def __len__(self):
return len(self.frame_ids)
from lerobot.common.utils.utils import init_logging
def run_server(
@@ -88,7 +71,6 @@ def run_server(
port: str,
static_folder: Path,
template_folder: Path,
has_policy=False,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@@ -131,101 +113,20 @@ def run_server(
dataset_info=dataset_info,
videos_info=videos_info,
ep_csv_url=ep_csv_url,
has_policy=has_policy,
has_policy=False,
)
app.run(host=host, port=port)
def run_inference(
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
):
if policy_method not in ["select_action", "forward"]:
raise ValueError(
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
)
policy.eval()
policy.to(device)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
batch_size=1 if policy_method == "select_action" else batch_size,
sampler=episode_sampler,
drop_last=False,
)
warned_ndim_eq_0 = False
warned_ndim_gt_2 = False
logging.info("Running inference")
inference_results = {}
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
with torch.inference_mode():
if policy_method == "select_action":
gt_action = batch.pop("action")
output_dict = {"action": policy.select_action(batch)}
batch["action"] = gt_action
elif policy_method == "forward":
output_dict = policy.forward(batch)
# TODO(rcadene): Save and display all predicted actions at a given timestamp
# Save predicted action for the next timestamp only
output_dict["action"] = output_dict["action"][:, 0, :]
for key in output_dict:
if output_dict[key].ndim == 0:
if not warned_ndim_eq_0:
warnings.warn(
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
stacklevel=1,
)
warned_ndim_eq_0 = True
continue
if output_dict[key].ndim > 2:
if not warned_ndim_gt_2:
warnings.warn(
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
stacklevel=1,
)
warned_ndim_gt_2 = True
continue
if key not in inference_results:
inference_results[key] = []
inference_results[key].append(output_dict[key].to("cpu"))
for key in inference_results:
inference_results[key] = torch.cat(inference_results[key])
return inference_results
def get_ep_csv_fname(episode_id: int):
ep_csv_fname = f"episode_{episode_id}.csv"
return ep_csv_fname
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
if policy is not None:
inference_results = run_inference(
dataset,
episode_index,
policy,
policy_method="select_action",
# num_workers=hydra_cfg.training.num_workers,
# batch_size=hydra_cfg.training.batch_size,
# device=hydra_cfg.device,
)
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
@@ -240,26 +141,21 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy
if has_action:
dim_action = dataset.meta.shapes["action"][0]
header += [f"action_{i}" for i in range(dim_action)]
if policy is not None:
dim_action = dataset.meta.shapes["action"][0]
header += [f"pred_action_{i}" for i in range(dim_action)]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
data = dataset.hf_dataset.select_columns(columns)
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
if policy is not None:
row += inference_results["action"][i].tolist()
rows.append(row)
output_dir.mkdir(parents=True, exist_ok=True)
@@ -287,9 +183,6 @@ def visualize_dataset_html(
host: str = "127.0.0.1",
port: int = 9090,
force_override: bool = False,
policy_method: str = "select_action",
pretrained_policy_name_or_path: str | None = None,
policy_overrides: list[str] | None = None,
) -> Path | None:
init_logging()
@@ -321,31 +214,15 @@ def visualize_dataset_html(
if episodes is None:
episodes = list(range(dataset.num_episodes))
pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
policy = None
if pretrained_policy_name_or_path is not None:
logging.info("Loading policy")
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
# dataset = make_dataset(hydra_cfg)
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
if policy_method == "select_action":
# Do not load previous observations or future actions, to simulate that the observations come from
# an environment.
dataset.delta_timestamps = None
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy=policy is not None)
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
@@ -404,8 +281,8 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
kwargs.pop("root")
dataset = LeRobotDataset(repo_id)
root = kwargs.pop("root")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
visualize_dataset_html(dataset, **kwargs)

View File

@@ -229,8 +229,7 @@
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
hasPolicy: {% if has_policy %}true{% else %}false{% endif %},
nColumns: {% if has_policy %}3{% else %}2{% endif %},
nColumns: 2,
nStates: 0,
nActions: 0,
checked: [],
@@ -279,9 +278,6 @@
const seriesNames = this.dygraph.getLabels().slice(1);
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
this.nActions = seriesNames.length - this.nStates;
if(this.hasPolicy){
this.nActions = Math.floor(this.nActions / 2);
}
const colors = [];
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
// colors for "state" lines
@@ -294,13 +290,6 @@
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
colors.push(color);
}
if(this.hasPolicy){
// colors for "action" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[2]}%)`;
colors.push(color);
}
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
@@ -338,10 +327,6 @@
// row consists of [state value, action value]
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
if(this.hasPolicy){
const predActionValueIdx = stateValueIdx + this.nStates + this.nActions; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN, pred_action1, ..., pred_actionN]
row.push(rowIndex < this.nActions ? this.currentFrameData[predActionValueIdx] : nullCell); // push "action value" to row
}
rowIndex += 1;
rows.push(row);
}

2770
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -70,7 +70,6 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
reachy2-sdk = {git = "https://github.com/pollen-robotics/reachy2-sdk", branch="450-opencv-dependency-version", optional = true}
jsonlines = ">=4.0.0"
@@ -87,7 +86,6 @@ dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
intelrealsense = ["pyrealsense2"]
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
reachy2 = ["reachy2-sdk"]
[tool.ruff]
line-length = 110