forked from tangger/lerobot
cleaning
This commit is contained in:
@@ -225,10 +225,7 @@ def load_episodes(local_dir: Path) -> dict:
|
|||||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||||
# is a dictionary of stats and not an integer.
|
# is a dictionary of stats and not an integer.
|
||||||
episode_stats = {
|
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||||
"episode_index": episode_index,
|
|
||||||
"stats": serialize_dict(episode_stats),
|
|
||||||
}
|
|
||||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||||
|
|
||||||
|
|
||||||
@@ -412,7 +409,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
|||||||
|
|
||||||
names = ft["names"]
|
names = ft["names"]
|
||||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||||
if names is not None and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||||
shape = (shape[2], shape[0], shape[1])
|
shape = (shape[2], shape[0], shape[1])
|
||||||
elif key == "observation.environment_state":
|
elif key == "observation.environment_state":
|
||||||
type = FeatureType.ENV
|
type = FeatureType.ENV
|
||||||
@@ -543,10 +540,7 @@ def check_timestamps_sync(
|
|||||||
|
|
||||||
|
|
||||||
def check_delta_timestamps(
|
def check_delta_timestamps(
|
||||||
delta_timestamps: dict[str, list[float]],
|
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||||
fps: int,
|
|
||||||
tolerance_s: float,
|
|
||||||
raise_value_error: bool = True,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||||
|
|||||||
@@ -79,46 +79,28 @@ def create_stats_buffers(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||||
if stats and key in stats:
|
if stats:
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
raise ValueError(
|
|
||||||
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
|
||||||
# unnormalization). See the logic here
|
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
|
||||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
type_ = type(stats[key]["mean"])
|
|
||||||
raise ValueError(
|
|
||||||
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
|
||||||
if "min" not in stats[key] or "max" not in stats[key]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(stats[key]["min"], np.ndarray):
|
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||||
elif isinstance(stats[key]["min"], torch.Tensor):
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
|
# unnormalization). See the logic here
|
||||||
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||||
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||||
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["min"])
|
type_ = type(stats[key]["mean"])
|
||||||
raise ValueError(
|
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||||
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
stats_buffers[key] = buffer
|
stats_buffers[key] = buffer
|
||||||
return stats_buffers
|
return stats_buffers
|
||||||
@@ -166,14 +148,11 @@ class Normalize(nn.Module):
|
|||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
|
||||||
# @torch.no_grad
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
if key not in batch:
|
if key not in batch:
|
||||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||||
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||||
@@ -241,8 +220,6 @@ class Unnormalize(nn.Module):
|
|||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
# TODO(rcadene): should we remove torch.no_grad?
|
|
||||||
# @torch.no_grad
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||||
for key, ft in self.features.items():
|
for key, ft in self.features.items():
|
||||||
|
|||||||
@@ -28,22 +28,14 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
from lerobot.common.robot_devices.motors.utils import (
|
from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs
|
||||||
MotorsBus,
|
|
||||||
make_motors_buses_from_configs,
|
|
||||||
)
|
|
||||||
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
||||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||||
from lerobot.common.robot_devices.utils import (
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
RobotDeviceAlreadyConnectedError,
|
|
||||||
RobotDeviceNotConnectedError,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_safe_goal_position(
|
def ensure_safe_goal_position(
|
||||||
goal_pos: torch.Tensor,
|
goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
|
||||||
present_pos: torch.Tensor,
|
|
||||||
max_relative_target: float | list[float],
|
|
||||||
):
|
):
|
||||||
# Cap relative action target magnitude for safety.
|
# Cap relative action target magnitude for safety.
|
||||||
diff = goal_pos - present_pos
|
diff = goal_pos - present_pos
|
||||||
@@ -53,7 +45,7 @@ def ensure_safe_goal_position(
|
|||||||
safe_goal_pos = present_pos + safe_diff
|
safe_goal_pos = present_pos + safe_diff
|
||||||
|
|
||||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||||
logging.debug(
|
logging.warning(
|
||||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||||
f" requested relative goal position target: {diff}\n"
|
f" requested relative goal position target: {diff}\n"
|
||||||
f" clamped relative goal position target: {safe_diff}"
|
f" clamped relative goal position target: {safe_diff}"
|
||||||
@@ -317,9 +309,7 @@ class ManipulatorRobot:
|
|||||||
print(f"Missing calibration file '{arm_calib_path}'")
|
print(f"Missing calibration file '{arm_calib_path}'")
|
||||||
|
|
||||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
|
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||||
run_arm_calibration,
|
|
||||||
)
|
|
||||||
|
|
||||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user