[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Authored by: pre-commit-ci[bot]
Date: 2025-03-04 13:38:47 +00:00
Committed by: Michel Aractingi
Parent: bb69cb3c8c
Commit: 85fe8a3f4e
79 changed files with 2800 additions and 794 deletions
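This commit was generated by pre-commit.ci: the service runs the repository's pre-commit hooks on the pull request and pushes the resulting fixes back as a single autofix commit (the message above is the service's default autofix commit message). The repository's actual .pre-commit-config.yaml is not part of this diff; the sketch below is a minimal, assumed configuration of the kind that produces such commits — the hook repo, ids, and revision are illustrative, not the project's real settings.

# Hypothetical minimal .pre-commit-config.yaml (assumed, not the project's actual file)
ci:
  autofix_prs: true  # let pre-commit.ci push fixes back to PRs, using its default commit message
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit  # assumed formatter hook, not necessarily what this repo pins
    rev: v0.9.6  # illustrative revision
    hooks:
      - id: ruff            # lint with autofix
        args: [--fix]
      - id: ruff-format     # formatting, e.g. the line-wrapping changes in the hunks below

The same fixes can be reproduced locally with pre-commit install followed by pre-commit run --all-files.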


@@ -108,20 +108,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break
if motor_index == -1:
raise ValueError("No motors detected. Please ensure you have one motor connected.")
raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
print(f"Motor index found at: {motor_index}")
if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids(
@@ -136,7 +142,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
if present_idx != motor_idx_des:
raise OSError("Failed to write index.")
@@ -164,12 +172,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
parser.add_argument(
"--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
"--port",
type=str,
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument(
"--ID",
type=int,
required=True,
help="Desired ID of the current motor (e.g. 1,2,3)",
)
parser.add_argument(
"--baudrate",
type=int,
default=1000000,
help="Desired baudrate for the motor (default: 1000000)",
)
args = parser.parse_args()


@@ -149,7 +149,11 @@ def init_sim_calibration(robot, cfg):
axis_directions = np.array(cfg.get("axis_directions", [1]))
offsets = np.array(cfg.get("offsets", [0])) * np.pi
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
return {
"start_pos": start_pos,
"axis_directions": axis_directions,
"offsets": offsets,
}
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
@@ -170,7 +174,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
print("Teleoperation processes finished.")
break
@@ -202,19 +209,27 @@ def record(
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
if policy is None and process_action_from_leader is None:
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
# initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@@ -256,7 +271,11 @@ def record(
"shape": env.observation_space[obs_key].shape,
}
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
features["action"] = {
"dtype": "float32",
"shape": env.action_space.shape,
"names": None,
}
features = {**features, **extra_features}
# Create empty dataset or load existing saved episodes
@@ -357,7 +376,9 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info("Waiting for a few seconds before starting next episode recording...")
logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
@@ -375,7 +396,12 @@ def record(
def replay(
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
env,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
local_files_only: bool = True,
):
env = env()
@@ -422,7 +448,10 @@ if __name__ == "__main__":
parser_record = subparsers.add_parser("record", parents=[base_parser])
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_record.add_argument(
"--root",
@@ -448,7 +477,9 @@ if __name__ == "__main__":
required=True,
help="A description of the task preformed during recording that can be used as a language instruction.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument(
"--run-compute-stats",
type=int,
@@ -509,7 +540,10 @@ if __name__ == "__main__":
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
"--fps",
type=none_or_int,
default=None,
help="Frames per second (set to None to disable)",
)
parser_replay.add_argument(
"--root",
@@ -523,7 +557,9 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episodes to replay."
)
args = parser.parse_args()


@@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
# TODO(aliberts): refactor into an actual command `lerobot env`
@@ -77,7 +81,9 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print(format_dict(info))
return info


@@ -174,7 +174,10 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
else:
successes = [False] * env.num_envs
@@ -188,9 +191,13 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@@ -208,7 +215,9 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
@@ -270,7 +279,9 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@@ -282,7 +293,9 @@ def eval_policy(
episode_data: dict | None = None
# we don't want a progress bar when we use slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -293,7 +306,8 @@ def eval_policy(
seeds = None
else:
seeds = range(
start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
start_seed + (batch_ix * env.num_envs),
start_seed + ((batch_ix + 1) * env.num_envs),
)
rollout_data = rollout(
env,
@@ -311,13 +325,22 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
# Extend metrics.
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@@ -330,17 +353,27 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -358,7 +391,9 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[: done_index + 1], # + 1 to capture the last observation
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@@ -367,7 +402,9 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
)
# Wait till all video rendering threads are done.
@@ -413,7 +450,11 @@ def eval_policy(
def _compile_episode_data(
rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
rollout_data: dict,
done_indices: Tensor,
start_episode_index: int,
start_data_index: int,
fps: float,
) -> dict:
"""Convenience function for `eval_policy(return_episode_data=True)`
@@ -431,12 +472,16 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
}
# For the last observation frame, all other keys will just be copy padded.
@@ -452,7 +497,9 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
return data_dict


@@ -46,7 +46,11 @@ import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.control_utils import (
busy_wait,
is_headless,
reset_follower_position,
)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
@@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
return
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
Classifier,
)
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
@@ -151,11 +161,17 @@ def rollout(
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
all_rewards.append(reward)
# print("REWARD : ", reward)
@@ -219,11 +235,19 @@ def eval_policy(
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
for _ in progbar:
rollout_data = rollout(
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
robot,
policy,
reward_classifier,
fps,
control_time_s,
use_amp,
display_cameras,
)
rollouts.append(rollout_data)
@@ -289,7 +313,9 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
@@ -301,7 +327,10 @@ def init_keyboard_listener():
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
log_say(
"Human intervention stage. Get ready to take over.",
play_sounds=True,
)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
@@ -351,7 +380,9 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument(
"--out-dir",
help=(
@@ -360,7 +391,8 @@ if __name__ == "__main__":
),
)
parser.add_argument(
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
"--display-cameras",
help=("Whether to display the camera feed while the rollout is happening"),
)
parser.add_argument(
"--reward-classifier-pretrained-path",


@@ -45,9 +45,13 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. No difference was found ({ports_diff})."
)
else:
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
raise OSError(
f"Could not detect the port. More than one port was found ({ports_diff})."
)
if __name__ == "__main__":


@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
@@ -737,7 +736,6 @@ def concatenate_batch_transitions(
if __name__ == "__main__":
import numpy as np
from tempfile import TemporaryDirectory
# ===== Test 1: Create and use a synthetic ReplayBuffer =====
@@ -1139,7 +1137,7 @@ if __name__ == "__main__":
savings_percent = (std_mem - opt_mem) / std_mem * 100
print(f"\nMemory optimization result:")
print("\nMemory optimization result:")
print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
print(f"- Memory savings for state tensors: {savings_percent:.1f}%")


@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser.add_argument(
"--repo-id",
type=str,
@@ -247,7 +249,9 @@ if __name__ == "__main__":
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@@ -256,7 +260,7 @@ if __name__ == "__main__":
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path, "r") as f:
with open(args.crop_params_path) as f:
rois = json.load(f)
# rois = {


@@ -31,7 +31,9 @@ def find_joint_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.waitKey(1)
timestamp = time.perf_counter() - start_episode_t
@@ -57,7 +59,12 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
parser.add_argument(
"--control-time-s",
type=float,
default=20,
help="Maximum episode length in seconds",
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)


@@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
def initialize_replay_buffer(
cfg: DictConfig, logger: Logger, device: str, storage_device:str
cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(


@@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
def preprocess_maniskill_observation(
observations: dict[str, np.ndarray],
) -> dict[str, torch.Tensor]:
"""Convert environment observation to LeRobot format observation.
Args:
observation: Dictionary of observation batches from a Gym vector environment.
@@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
def action(self, action):
action, telop = action
@@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
def step(self, action):
if isinstance(action, tuple):
@@ -137,7 +145,9 @@ def make_maniskill(
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
@@ -149,10 +159,11 @@ def make_maniskill(
if __name__ == "__main__":
import argparse
import hydra
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
args = parser.parse_args()
# Initialize config


@@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy):
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
optimizer_params_dicts,
lr=cfg.training.lr,
weight_decay=cfg.training.weight_decay,
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
@@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy):
optimizer = torch.optim.Adam(
[
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
{
"params": policy.critic_ensemble.parameters(),
"lr": policy.config.critic_lr,
},
{
"params": policy.temperature.parameters(),
"lr": policy.config.temperature_lr,
},
]
)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
from lerobot.common.policies.vqbet.modeling_vqbet import (
VQBeTOptimizer,
VQBeTScheduler,
)
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
@@ -214,7 +225,9 @@ def train(cfg: TrainPipelineConfig):
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")


@@ -14,7 +14,6 @@
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
@@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
ClassifierConfig,
)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
@@ -50,7 +51,11 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
return model
@@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
else:
outputs = model(images)
@@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
elif device.type == "mps":
torch.mps.synchronize()
with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
_ = model(x)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
)
inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
avg, median, std = (
inference_times.mean(),
np.median(inference_times),
inference_times.std(),
)
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
@@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
@hydra.main(
version_base="1.2",
config_path="../configs/policy",
config_name="hilserl_classifier",
)
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
@@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:
# Setup dataset and dataloaders
dataset = LeRobotDataset(
cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
cfg.dataset_repo_id,
root=cfg.dataset_root,
local_files_only=cfg.local_files_only,
)
logging.info(f"Dataset size: {len(dataset)}")
@@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
@@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:


@@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm
@@ -30,20 +29,17 @@ from tqdm import tqdm
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.envs.factory import make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy
def make_optimizers_and_scheduler(cfg, policy):
@@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@@ -198,8 +198,12 @@ class ReplayBuffer:
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
replay_buffer = cls(
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
replay_buffer.add(
@@ -244,7 +248,9 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
transitions: list[Transition] = []
num_frames = len(dataset)
@@ -298,36 +304,40 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
# -- Build batched dones --
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
# Return a BatchTransition typed dict
return BatchTransition(
@@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
"""NOTE: Be careful it change the left_batch_transitions in place"""
left_batch_transitions["state"] = {
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
key: torch.cat(
[
left_batch_transitions["state"][key],
right_batch_transition["state"][key],
],
dim=0,
)
for key in left_batch_transitions["state"]
}
left_batch_transitions["action"] = torch.cat(
@@ -355,7 +371,11 @@ def concatenate_batch_transitions(
)
left_batch_transitions["next_state"] = {
key: torch.cat(
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
[
left_batch_transitions["next_state"][key],
right_batch_transition["next_state"][key],
],
dim=0,
)
for key in left_batch_transitions["next_state"]
}
@@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
device=device,
)
assert isinstance(policy, nn.Module)
@@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
replay_buffer = ReplayBuffer(
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
capacity=cfg.training.online_buffer_capacity,
device=device,
state_keys=cfg.policy.input_shapes.keys(),
)
batch_size = cfg.training.batch_size
@@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
@@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
replay_buffer.add(
state=obs,
@@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(observations=observations)
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -573,7 +611,9 @@ def train_cli(cfg: dict):
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()


@@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
assert (
c < h and c < w
), f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
return hwc_uint8_numpy


@@ -81,7 +81,11 @@ def run_server(
static_folder: Path,
template_folder: Path,
):
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
app = Flask(
__name__,
static_folder=static_folder.resolve(),
template_folder=template_folder.resolve(),
)
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
@@ -138,8 +142,12 @@ def run_server(
)
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
@@ -150,7 +158,9 @@ def run_server(
400,
)
dataset_version = (
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
str(dataset.meta._version)
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@@ -158,7 +168,9 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
episode_data_csv_str, columns, ignored_columns = get_episode_data(
dataset, episode_id
)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -171,18 +183,23 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
]
videos_info = [
{
"url": url_for("static", filename=str(video_path).replace("\\", "/")),
"url": url_for(
"static", filename=str(video_path).replace("\\", "/")
),
"filename": video_path.parent.name,
}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@@ -197,20 +214,29 @@ def run_server(
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
timeout=5,
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
)
return render_template(
@@ -237,7 +263,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns = [
col
for col, ft in dataset.features.items()
if ft["dtype"] in ["float32", "int32"]
]
selected_columns.remove("timestamp")
ignored_columns = []
@@ -258,7 +288,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0]
)
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@@ -281,8 +314,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
@@ -315,7 +352,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@@ -326,12 +365,15 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
timeout=5,
)
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
@@ -361,7 +403,9 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
output_dir.mkdir(parents=True, exist_ok=True)