[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Author: pre-commit-ci[bot]
Date: 2025-03-24 13:41:27 +00:00
Committed by: Michel Aractingi
Parent: 2abbd60a0d
Commit: 0ea27704f6
123 changed files with 1161 additions and 3425 deletions

View File

@@ -67,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple:
def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = (
get_motor_bus_cls(brand)
)
motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
brand
)
# Check if the provided model exists in the model_baud_rate_table
@@ -82,9 +82,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument
motor_model = model # Use the motor model passed via argument
config = motor_bus_config_cls(
port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
)
config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})
# Initialize the MotorBus with the correct port and motor configurations
motor_bus = motor_bus_cls(config=config)
@@ -120,26 +118,20 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
break
if motor_index == -1:
raise ValueError(
"No motors detected. Please ensure you have one motor connected."
)
raise ValueError("No motors detected. Please ensure you have one motor connected.")
print(f"Motor index found at: {motor_index}")
if brand == "feetech":
# Allows ID and BAUDRATE to be written in memory
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
if baudrate != baudrate_des:
print(f"Setting its baudrate to {baudrate_des}")
baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)
# The write can fail, so we allow retries
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
time.sleep(0.5)
motor_bus.set_bus_baudrate(baudrate_des)
present_baudrate_idx = motor_bus.read_with_motor_ids(
@@ -151,16 +143,10 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
print(f"Setting its index to desired index {motor_idx_des}")
if brand == "feetech":
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "Lock", 0
)
motor_bus.write_with_motor_ids(
motor_bus.motor_models, motor_index, "ID", motor_idx_des
)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
present_idx = motor_bus.read_with_motor_ids(
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
)
present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
if present_idx != motor_idx_des:
raise OSError("Failed to write index.")
@@ -194,12 +180,8 @@ if __name__ == "__main__":
required=True,
help="Motors bus port (e.g. dynamixel,feetech)",
)
parser.add_argument(
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
)
parser.add_argument(
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
)
parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
parser.add_argument(
"--ID",
type=int,

View File

@@ -255,8 +255,7 @@ def record(
if len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera
* len(robot.cameras),
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else:
@@ -269,19 +268,14 @@ def record(
robot=robot,
use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera
* len(robot.cameras),
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = (
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
# Load pretrained policy
policy = (
None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
)
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()

View File

@@ -174,10 +174,7 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
leader_pos = robot.leader_arms.main.read("Present_Position")
action = process_action_fn(leader_pos)
env.step(np.expand_dims(action, 0))
if (
teleop_time_s is not None
and time.perf_counter() - start_teleop_t > teleop_time_s
):
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
print("Teleoperation processes finished.")
break
@@ -209,27 +206,19 @@ def record(
# Load pretrained policy
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
if assign_rewards
else None
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
policy = None
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(
pretrained_policy_name_or_path, policy_overrides
)
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
if fps is None:
fps = policy_fps
logging.warning(
f"No fps provided, so using the fps from policy config ({policy_fps})."
)
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
if policy is None and process_action_from_leader is None:
raise ValueError(
"Either policy or process_action_fn has to be set to enable control in sim."
)
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
# initialize listener before sim env
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@@ -380,9 +369,7 @@ def record(
if events["stop_recording"] or recorded_episodes >= num_episodes:
break
else:
logging.info(
"Waiting for a few seconds before starting next episode recording..."
)
logging.info("Waiting for a few seconds before starting next episode recording...")
busy_wait(3)
log_say("Stop recording", play_sounds, blocking=True)
@@ -481,9 +468,7 @@ if __name__ == "__main__":
required=True,
help="A description of the task preformed during recording that can be used as a language instruction.",
)
parser_record.add_argument(
"--num-episodes", type=int, default=50, help="Number of episodes to record."
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
@@ -561,9 +546,7 @@ if __name__ == "__main__":
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument(
"--episode", type=int, default=0, help="Index of the episodes to replay."
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
args = parser.parse_args()

View File

@@ -59,11 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = (
torch._C._cuda_getCompiledVersion()
if HAS_TORCH and torch.version.cuda is not None
else "N/A"
)
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
# TODO(aliberts): refactor into an actual command `lerobot env`
@@ -81,9 +77,7 @@ def display_sys_info() -> dict:
"Using GPU in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print(
"\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
)
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(format_dict(info))
return info

View File

@@ -152,8 +152,7 @@ def rollout(
all_observations.append(deepcopy(observation))
observation = {
key: observation[key].to(device, non_blocking=device.type == "cuda")
for key in observation
key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
}
# Infer "task" from attributes of environments.
@@ -175,10 +174,7 @@ def rollout(
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
successes = [
info["is_success"] if info is not None else False
for info in info["final_info"]
]
successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
else:
successes = [False] * env.num_envs
@@ -192,13 +188,9 @@ def rollout(
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
.numpy()
.mean()
)
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
progbar.set_postfix(
{"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
# Track the final observation.
@@ -216,9 +208,7 @@ def rollout(
if return_observations:
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack(
[obs[key] for obs in all_observations], dim=1
)
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret["observation"] = stacked_observations
if hasattr(policy, "use_original_modules"):
@@ -280,9 +270,7 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(
np.stack([env.envs[i].render() for i in range(n_to_render_now)])
) # noqa: B023
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
elif isinstance(env, gym.vector.AsyncVectorEnv):
# Here we must render all frames and discard any we don't need.
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@@ -294,9 +282,7 @@ def eval_policy(
episode_data: dict | None = None
# we don't want a progress bar when we use slurm, since it clutters the logs
progbar = trange(
n_batches, desc="Stepping through eval batches", disable=inside_slurm()
)
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -326,22 +312,13 @@ def eval_policy(
# Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
# (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
mask = (
torch.arange(n_steps)
<= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
).int()
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
# Extend metrics.
batch_sum_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "sum"
)
batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
sum_rewards.extend(batch_sum_rewards.tolist())
batch_max_rewards = einops.reduce(
(rollout_data["reward"] * mask), "b n -> b", "max"
)
batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
max_rewards.extend(batch_max_rewards.tolist())
batch_successes = einops.reduce(
(rollout_data["success"] * mask), "b n -> b", "any"
)
batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
all_successes.extend(batch_successes.tolist())
if seeds:
all_seeds.extend(seeds)
@@ -354,27 +331,17 @@ def eval_policy(
rollout_data,
done_indices,
start_episode_index=batch_ix * env.num_envs,
start_data_index=(
0
if episode_data is None
else (episode_data["index"][-1].item() + 1)
),
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
fps=env.unwrapped.metadata["render_fps"],
)
if episode_data is None:
episode_data = this_episode_data
else:
# Some sanity checks to make sure we are correctly compiling the data.
assert (
episode_data["episode_index"][-1] + 1
== this_episode_data["episode_index"][0]
)
assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
# Concatenate the episode data.
episode_data = {
k: torch.cat([episode_data[k], this_episode_data[k]])
for k in episode_data
}
episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
# Maybe render video for visualization.
if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -392,9 +359,7 @@ def eval_policy(
target=write_video,
args=(
str(video_path),
stacked_frames[
: done_index + 1
], # + 1 to capture the last observation
stacked_frames[: done_index + 1], # + 1 to capture the last observation
env.unwrapped.metadata["render_fps"],
),
)
@@ -403,9 +368,7 @@ def eval_policy(
n_episodes_rendered += 1
progbar.set_postfix(
{
"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
}
{"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
)
# Wait till all video rendering threads are done.
@@ -473,16 +436,12 @@ def _compile_episode_data(
# Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
ep_dict = {
"action": rollout_data["action"][ep_ix, : num_frames - 1],
"episode_index": torch.tensor(
[start_episode_index + ep_ix] * (num_frames - 1)
),
"episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
"frame_index": torch.arange(0, num_frames - 1, 1),
"timestamp": torch.arange(0, num_frames - 1, 1) / fps,
"next.done": rollout_data["done"][ep_ix, : num_frames - 1],
"next.success": rollout_data["success"][ep_ix, : num_frames - 1],
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
torch.float32
),
"next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
}
# For the last observation frame, all other keys will just be copy padded.
@@ -498,9 +457,7 @@ def _compile_episode_data(
for key in ep_dicts[0]:
data_dict[key] = torch.cat([x[key] for x in ep_dicts])
data_dict["index"] = torch.arange(
start_data_index, start_data_index + total_frames, 1
)
data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
return data_dict
@@ -516,14 +473,10 @@ def eval_main(cfg: EvalPipelineConfig):
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
env = make_env(
cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
)
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
@@ -535,9 +488,7 @@ def eval_main(cfg: EvalPipelineConfig):
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
info = eval_policy(
env,
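A note on the mask reformatted in the hunks above: the `(torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps))` expression is easy to sanity-check in isolation. A minimal sketch with made-up toy values, using only `torch` and `einops` as in the script itself:

import torch
import einops

n_steps = 5
done_indices = torch.tensor([1, 3])  # toy values: first done step for each of two batch elements
# Same mask construction as in eval_policy above.
mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
print(mask)
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

Masked reductions such as `einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")` then only aggregate steps up to each episode's end.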

View File

@@ -74,9 +74,7 @@ def get_classifier(pretrained_path, config_path):
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to("mps")
@@ -161,17 +159,11 @@ def rollout(
images = []
for key in image_keys:
if display_cameras:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
images.append(observation[key].to("mps"))
reward = (
reward_classifier.predict_reward(images)
if reward_classifier is not None
else 0.0
)
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
all_rewards.append(reward)
# print("REWARD : ", reward)
@@ -235,9 +227,7 @@ def eval_policy(
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
reward_classifier = get_classifier(
reward_classifier_pretrained_path, reward_classifier_config_file
)
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
for _ in progbar:
rollout_data = rollout(
@@ -313,9 +303,7 @@ def init_keyboard_listener():
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print(
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
)
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
@@ -380,9 +368,7 @@ if __name__ == "__main__":
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument(
"--revision", help="Optionally provide the Hugging Face Hub revision ID."
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--out-dir",
help=(

View File

@@ -45,13 +45,9 @@ def find_port():
print(f"The port of this MotorsBus is '{port}'")
print("Reconnect the USB cable.")
elif len(ports_diff) == 0:
raise OSError(
f"Could not detect the port. No difference was found ({ports_diff})."
)
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
else:
raise OSError(
f"Could not detect the port. More than one port was found ({ports_diff})."
)
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
if __name__ == "__main__":

View File

@@ -14,18 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from statistics import mean, quantiles
import time
from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers
from queue import Empty
from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import time
from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
@@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
Transition,
bytes_to_state_dict,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
bytes_to_state_dict,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks,
send_bytes_in_chunks,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from lerobot.common.robot_devices.utils import busy_wait
from torch.multiprocessing import Queue, Event
from queue import Empty
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.utils import get_last_item_from_queue
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -102,9 +96,7 @@ def receive_policy(
logging.info("[ACTOR] Received policy loop stopped")
def transitions_stream(
shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
@@ -169,9 +161,7 @@ def send_transitions(
)
try:
learner_client.SendTransitions(
transitions_stream(shutdown_event, transitions_queue)
)
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -211,9 +201,7 @@ def send_interactions(
)
try:
learner_client.SendInteractions(
interactions_stream(shutdown_event, interactions_queue)
)
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -301,9 +289,7 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@@ -355,13 +341,9 @@ def act_with_policy(
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
@@ -369,9 +351,7 @@ def act_with_policy(
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
@@ -391,9 +371,7 @@ def act_with_policy(
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append(
Transition(
@@ -413,13 +391,9 @@ def act_with_policy(
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
return stats
def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"

View File

@@ -14,16 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import io
import os
import pickle
from typing import Any, Callable, Optional, Sequence, TypedDict
import io
import torch
import torch.nn.functional as F # noqa: N812
from tqdm import tqdm
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import os
import pickle
class Transition(TypedDict):
@@ -45,38 +45,27 @@ class BatchTransition(TypedDict):
truncated: torch.Tensor
def move_transition_to_device(
transition: Transition, device: str = "cpu"
) -> Transition:
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
# Move state tensors to CPU
device = torch.device(device)
transition["state"] = {
key: val.to(device, non_blocking=device.type == "cuda")
for key, val in transition["state"].items()
key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
}
# Move action to CPU
transition["action"] = transition["action"].to(
device, non_blocking=device.type == "cuda"
)
transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")
# No need to move reward or done, as they are float and bool
# No need to move reward or done, as they are float and bool
if isinstance(transition["reward"], torch.Tensor):
transition["reward"] = transition["reward"].to(
device=device, non_blocking=device.type == "cuda"
)
transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")
if isinstance(transition["done"], torch.Tensor):
transition["done"] = transition["done"].to(
device, non_blocking=device.type == "cuda"
)
transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")
if isinstance(transition["truncated"], torch.Tensor):
transition["truncated"] = transition["truncated"].to(
device, non_blocking=device.type == "cuda"
)
transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")
# Move next_state tensors to CPU
transition["next_state"] = {
@@ -100,10 +89,7 @@ def move_state_dict_to_device(state_dict, device="cpu"):
if isinstance(state_dict, torch.Tensor):
return state_dict.to(device)
elif isinstance(state_dict, dict):
return {
k: move_state_dict_to_device(v, device=device)
for k, v in state_dict.items()
}
return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
elif isinstance(state_dict, list):
return [move_state_dict_to_device(v, device=device) for v in state_dict]
elif isinstance(state_dict, tuple):
@@ -174,9 +160,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
@@ -223,9 +207,7 @@ class ReplayBuffer:
self.optimize_memory = optimize_memory
# Track episode boundaries for memory optimization
self.episode_ends = torch.zeros(
capacity, dtype=torch.bool, device=storage_device
)
self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)
# If no state_keys provided, default to an empty list
self.state_keys = state_keys if state_keys is not None else []
@@ -246,9 +228,7 @@ class ReplayBuffer:
key: torch.empty((self.capacity, *shape), device=self.storage_device)
for key, shape in state_shapes.items()
}
self.actions = torch.empty(
(self.capacity, *action_shape), device=self.storage_device
)
self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
self.rewards = torch.empty((self.capacity,), device=self.storage_device)
if not self.optimize_memory:
@@ -262,12 +242,8 @@ class ReplayBuffer:
# Just create a reference to states for consistent API
self.next_states = self.states # Just a reference for API consistency
self.dones = torch.empty(
(self.capacity,), dtype=torch.bool, device=self.storage_device
)
self.truncateds = torch.empty(
(self.capacity,), dtype=torch.bool, device=self.storage_device
)
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
self.initialized = True
def __len__(self):
@@ -294,9 +270,7 @@ class ReplayBuffer:
if not self.optimize_memory:
# Only store next_states if not optimizing memory
self.next_states[key][self.position].copy_(
next_state[key].squeeze(dim=0)
)
self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))
self.actions[self.position].copy_(action.squeeze(dim=0))
self.rewards[self.position] = reward
@@ -309,23 +283,15 @@ class ReplayBuffer:
def sample(self, batch_size: int) -> BatchTransition:
"""Sample a random batch of transitions and collate them into batched tensors."""
if not self.initialized:
raise RuntimeError(
"Cannot sample from an empty buffer. Add transitions first."
)
raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")
batch_size = min(batch_size, self.size)
# Random indices for sampling - create on the same device as storage
idx = torch.randint(
low=0, high=self.size, size=(batch_size,), device=self.storage_device
)
idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)
# Identify image keys that need augmentation
image_keys = (
[k for k in self.states if k.startswith("observation.image")]
if self.use_drq
else []
)
image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []
# Create batched state and next_state
batch_state = {}
@@ -358,13 +324,9 @@ class ReplayBuffer:
# Split the augmented images back to their sources
for i, key in enumerate(image_keys):
# State images are at even indices (0, 2, 4...)
batch_state[key] = augmented_images[
i * 2 * batch_size : (i * 2 + 1) * batch_size
]
batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
# Next state images are at odd indices (1, 3, 5...)
batch_next_state[key] = augmented_images[
(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
]
batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]
# Sample other tensors
batch_actions = self.actions[idx].to(self.device)
@@ -434,16 +396,12 @@ class ReplayBuffer:
)
# Convert dataset to transitions
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Initialize the buffer with the first transition to set up storage tensors
if list_transition:
first_transition = list_transition[0]
first_state = {
k: v.to(device) for k, v in first_transition["state"].items()
}
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
first_action = first_transition["action"].to(device)
# Apply action mask/delta if needed
@@ -541,9 +499,7 @@ class ReplayBuffer:
# Convert transitions into episodes and frames
episode_index = 0
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
episode_index=episode_index
)
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
frame_idx_in_episode = 0
for idx in range(self.size):
@@ -557,12 +513,8 @@ class ReplayBuffer:
# Fill action, reward, done
frame_dict["action"] = self.actions[actual_idx].cpu()
frame_dict["next.reward"] = torch.tensor(
[self.rewards[actual_idx]], dtype=torch.float32
).cpu()
frame_dict["next.done"] = torch.tensor(
[self.dones[actual_idx]], dtype=torch.bool
).cpu()
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
# Add to the dataset's buffer
lerobot_dataset.add_frame(frame_dict)
@@ -619,9 +571,7 @@ class ReplayBuffer:
A list of Transition dictionaries with the same length as `dataset`.
"""
if state_keys is None:
raise ValueError(
"State keys must be provided when converting LeRobotDataset to Transitions."
)
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
transitions = []
num_frames = len(dataset)
@@ -632,9 +582,7 @@ class ReplayBuffer:
# If not, we need to infer it from episode boundaries
if not has_done_key:
print(
"'next.done' key not found in dataset. Inferring from episode boundaries..."
)
print("'next.done' key not found in dataset. Inferring from episode boundaries...")
for i in tqdm(range(num_frames)):
current_sample = dataset[i]
@@ -886,8 +834,7 @@ if __name__ == "__main__":
# We need to be careful because we don't know the original index
# So we check if the increment is roughly 0.01
next_state_check = (
abs(next_state_sig - state_sig - 0.01) < 1e-4
or abs(next_state_sig - state_sig) < 1e-4
abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
)
# Count correct relationships
@@ -901,17 +848,11 @@ if __name__ == "__main__":
total_checks += 3
alignment_accuracy = 100.0 * correct_relationships / total_checks
print(
f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
)
print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
if alignment_accuracy > 99.0:
print(
"✅ All relationships verified! Buffer maintains correct temporal relationships."
)
print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
else:
print(
"⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
)
print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")
# Print some debug information about failures
print("\nDebug information for failed checks:")
@@ -973,18 +914,14 @@ if __name__ == "__main__":
# Verify consistency before and after conversion
original_states = batch["state"]["observation.image"].mean().item()
reconverted_states = (
reconverted_batch["state"]["observation.image"].mean().item()
)
reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
print(f"Original buffer state mean: {original_states:.4f}")
print(f"Reconverted buffer state mean: {reconverted_states:.4f}")
if abs(original_states - reconverted_states) < 1.0:
print("Values are reasonably similar - conversion works as expected")
else:
print(
"WARNING: Significant difference between original and reconverted values"
)
print("WARNING: Significant difference between original and reconverted values")
print("\nAll previous tests completed!")
@@ -1093,15 +1030,11 @@ if __name__ == "__main__":
all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)
# Get state tensors
batch_state = {
"value": test_buffer.states["value"][all_indices].to(test_buffer.device)
}
batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}
# Get next_state using memory-optimized approach (simply index+1)
next_indices = (all_indices + 1) % test_buffer.capacity
batch_next_state = {
"value": test_buffer.states["value"][next_indices].to(test_buffer.device)
}
batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}
# Get other tensors
batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
@@ -1121,9 +1054,7 @@ if __name__ == "__main__":
print("- We always use the next state in the buffer (index+1) as next_state")
print("- For terminal states, this means using the first state of the next episode")
print("- This is a common tradeoff in RL implementations for memory efficiency")
print(
"- Since we track done flags, the algorithm can handle these transitions correctly"
)
print("- Since we track done flags, the algorithm can handle these transitions correctly")
# Test random sampling
print("\nVerifying random sampling with simplified memory optimization...")
@@ -1137,23 +1068,19 @@ if __name__ == "__main__":
# Print a few samples
print("Random samples - State, Next State, Done (First 10):")
for i in range(10):
print(
f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
)
print(f" {random_state_values[i]:.1f}{random_next_values[i]:.1f}, Done: {random_done_flags[i]}")
# Calculate memory savings
# Assume optimized_buffer and standard_buffer have already been initialized and filled
std_mem = (
sum(
standard_buffer.states[key].nelement()
* standard_buffer.states[key].element_size()
standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
for key in standard_buffer.states
)
* 2
)
opt_mem = sum(
optimized_buffer.states[key].nelement()
* optimized_buffer.states[key].element_size()
optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
for key in optimized_buffer.states
)
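The printed notes near the end of this buffer test script ("We always use the next state in the buffer (index+1) as next_state", "Since we track done flags, the algorithm can handle these transitions correctly") summarize the memory optimization. A minimal, hypothetical sketch of that indexing idea, with toy tensors standing in for the buffer storage:

import torch

capacity = 8
states = torch.arange(capacity, dtype=torch.float32)  # stand-in for one state key's storage
dones = torch.zeros(capacity, dtype=torch.bool)
dones[3] = True  # pretend an episode ends at slot 3

idx = torch.tensor([1, 3, 5])        # sampled transition indices
next_idx = (idx + 1) % capacity      # reuse the following slot as next_state (no separate next_states storage)
batch_state = states[idx]            # tensor([1., 3., 5.])
batch_next_state = states[next_idx]  # tensor([2., 4., 6.]); for idx 3 this is the next episode's first state
batch_done = dones[idx]              # tensor([False, True, False]) lets the learner ignore that bootstrap

This mirrors the `(all_indices + 1) % test_buffer.capacity` lookup used in the sequential-sampling check above.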

View File

@@ -225,9 +225,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Crop rectangular ROIs from a LeRobot dataset."
)
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
@@ -249,9 +247,7 @@ if __name__ == "__main__":
args = parser.parse_args()
local_files_only = args.root is not None
dataset = LeRobotDataset(
repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
)
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}

View File

@@ -1,13 +1,14 @@
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.scripts.server.kinematics import RobotKinematics
import argparse
import logging
import time
import torch
import numpy as np
import argparse
import numpy as np
import torch
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
@@ -187,9 +188,7 @@ class KeyboardController(InputController):
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
def __init__(
self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
):
def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
super().__init__(x_step_size, y_step_size, z_step_size)
self.deadzone = deadzone
self.joystick = None
@@ -203,9 +202,7 @@ class GamepadController(InputController):
pygame.joystick.init()
if pygame.joystick.get_count() == 0:
logging.error(
"No gamepad detected. Please connect a gamepad and try again."
)
logging.error("No gamepad detected. Please connect a gamepad and try again.")
self.running = False
return
@@ -338,18 +335,12 @@ class GamepadControllerHID(InputController):
devices = hid.enumerate()
for device in devices:
if (
device["vendor_id"] == self.vendor_id
and device["product_id"] == self.product_id
):
logging.info(
f"Found gamepad: {device.get('product_string', 'Unknown')}"
)
if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
return device
logging.error(
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
f"product ID 0x{self.product_id:04X} found"
f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
)
return None
@@ -381,9 +372,7 @@ class GamepadControllerHID(InputController):
except OSError as e:
logging.error(f"Error opening gamepad: {e}")
logging.error(
"You might need to run this with sudo/admin privileges on some systems"
)
logging.error("You might need to run this with sudo/admin privileges on some systems")
self.running = False
def stop(self):
@@ -421,12 +410,8 @@ class GamepadControllerHID(InputController):
# Apply deadzone
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
self.right_x = (
0 if abs(self.right_x) < self.deadzone else self.right_x
)
self.right_y = (
0 if abs(self.right_y) < self.deadzone else self.right_y
)
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
# Parse button states (byte 5 in the Logitech RumblePad 2)
buttons = data[5]
@@ -493,9 +478,7 @@ def test_inverse_kinematics(robot, fps=10):
joint_positions = obs["observation.state"].cpu().numpy()
ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
desired_ee_pos = ee_pos
target_joint_state = RobotKinematics.ik(
joint_positions, desired_ee_pos, position_only=True
)
target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
robot.send_action(torch.from_numpy(target_joint_state))
logging.info(f"Target Joint State: {target_joint_state}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -573,17 +556,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
robot.send_action(torch.from_numpy(target_joint_state))
# Logging
logging.info(
f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}"
)
logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
logging.info(f"Delta EE: {ee_delta[:3, 3]}")
busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
def teleoperate_delta_inverse_kinematics(
robot, controller, fps=10, bounds=None, fk_func=None
):
def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
"""
Control a robot using delta end-effector movements from any input controller.
@@ -597,9 +576,7 @@ def teleoperate_delta_inverse_kinematics(
if fk_func is None:
fk_func = RobotKinematics.fk_gripper_tip
logging.info(
f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
)
logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")
# Initial position capture
obs = robot.capture_observation()
@@ -631,9 +608,7 @@ def teleoperate_delta_inverse_kinematics(
# Apply bounds if provided
if bounds is not None:
desired_ee_pos[:3, 3] = np.clip(
desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
)
desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])
# Only send commands if there's actual movement
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
@@ -684,14 +659,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
# Step the environment - pass action as a tensor with intervention flag
action_tensor = torch.from_numpy(action.astype(np.float32))
obs, reward, terminated, truncated, info = env.step(
(action_tensor, False)
)
obs, reward, terminated, truncated, info = env.step((action_tensor, False))
# Log information
logging.info(
f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
)
logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
logging.info(f"Reward: {reward}")
# Reset if episode ended
@@ -761,20 +732,14 @@ if __name__ == "__main__":
# Determine controller type based on mode prefix
controller = None
if args.mode.startswith("keyboard"):
controller = KeyboardController(
x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
)
controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
elif args.mode.startswith("gamepad"):
controller = GamepadController(
x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
)
controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)
# Handle mode categories
if args.mode in ["keyboard", "gamepad"]:
# Direct robot control modes
teleoperate_delta_inverse_kinematics(
robot, controller, bounds=bounds, fps=10
)
teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)
elif args.mode in ["keyboard_gym", "gamepad_gym"]:
# Gym environment control modes

View File

@@ -32,9 +32,7 @@ def find_joint_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:
@@ -69,9 +67,7 @@ def find_ee_bounds(
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
for key in image_keys:
cv2.imshow(
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
)
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
if time.perf_counter() - start_episode_t > control_time_s:

View File

@@ -1,10 +1,10 @@
import argparse
import sys
import logging
import sys
import time
from threading import Lock
from typing import Annotated, Any, Dict, Tuple
import gymnasium as gym
import numpy as np
import torch
@@ -18,7 +18,6 @@ from lerobot.common.robot_devices.control_utils import (
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.kinematics import RobotKinematics
logging.basicConfig(level=logging.INFO)
@@ -67,9 +66,7 @@ class HILSerlRobotEnv(gym.Env):
if not self.robot.is_connected:
self.robot.connect()
self.initial_follower_position = robot.follower_arms["main"].read(
"Present_Position"
)
self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")
# Episode tracking.
self.current_step = 0
@@ -77,9 +74,7 @@ class HILSerlRobotEnv(gym.Env):
self.delta = delta
self.use_delta_action_space = use_delta_action_space
self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
# Retrieve the size of the joint position interval bound.
self.relative_bounds_size = (
@@ -92,9 +87,7 @@ class HILSerlRobotEnv(gym.Env):
)
self.robot.config.max_relative_target = (
self.relative_bounds_size.float()
if self.relative_bounds_size is not None
else None
self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
)
# Dynamically configure the observation and action spaces.
@@ -119,9 +112,7 @@ class HILSerlRobotEnv(gym.Env):
# Define observation spaces for images and other states.
image_keys = [key for key in example_obs if "image" in key]
observation_spaces = {
key: gym.spaces.Box(
low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
)
key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
for key in image_keys
}
observation_spaces["observation.state"] = gym.spaces.Box(
@@ -172,9 +163,7 @@ class HILSerlRobotEnv(gym.Env):
),
)
def reset(
self, seed=None, options=None
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
"""
Reset the environment to its initial state.
This method resets the step counter and clears any episodic data.
@@ -231,35 +220,25 @@ class HILSerlRobotEnv(gym.Env):
"""
policy_action, intervention_bool = action
teleop_action = None
self.current_joint_positions = self.robot.follower_arms["main"].read(
"Present_Position"
)
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
if isinstance(policy_action, torch.Tensor):
policy_action = policy_action.cpu().numpy()
policy_action = np.clip(
policy_action, self.action_space[0].low, self.action_space[0].high
)
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
if not intervention_bool:
if self.use_delta_action_space:
target_joint_positions = (
self.current_joint_positions + self.delta * policy_action
)
target_joint_positions = self.current_joint_positions + self.delta * policy_action
else:
target_joint_positions = policy_action
self.robot.send_action(torch.from_numpy(target_joint_positions))
observation = self.robot.capture_observation()
else:
observation, teleop_action = self.robot.teleop_step(record_data=True)
teleop_action = teleop_action[
"action"
] # Convert tensor to appropriate format
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
# When applying the delta action space, convert teleop absolute values to relative differences.
if self.use_delta_action_space:
teleop_action = (
teleop_action - self.current_joint_positions
) / self.delta
teleop_action = (teleop_action - self.current_joint_positions) / self.delta
if self.relative_bounds_size is not None and (
torch.any(teleop_action < -self.relative_bounds_size)
and torch.any(teleop_action > self.relative_bounds_size)
@@ -333,12 +312,8 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
self.last_joint_positions = np.zeros(old_shape)
new_low = np.concatenate(
[old_low, np.ones_like(old_low) * -joint_velocity_limits]
)
new_high = np.concatenate(
[old_high, np.ones_like(old_high) * joint_velocity_limits]
)
new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
new_shape = (old_shape[0] * 2,)
@@ -352,9 +327,7 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
self.dt = 1.0 / fps
def observation(self, observation):
joint_velocities = (
observation["observation.state"] - self.last_joint_positions
) / self.dt
joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt
self.last_joint_positions = observation["observation.state"].clone()
observation["observation.state"] = torch.cat(
[observation["observation.state"], joint_velocities], dim=-1
@@ -439,9 +412,7 @@ class JointMaskingActionSpace(gym.Wrapper):
raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(
low=low, high=high, dtype=env.action_space.dtype
)
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)
if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]:
@@ -449,12 +420,8 @@ class JointMaskingActionSpace(gym.Wrapper):
low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(
low=low, high=high, dtype=env.action_space[0].dtype
)
self.action_space = gym.spaces.Tuple(
(action_space_masked, env.action_space[1])
)
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
# Create new action space with masked dimensions
def action(self, action):
@@ -473,18 +440,14 @@ class JointMaskingActionSpace(gym.Wrapper):
# Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element.
full_box_action = np.zeros(
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
)
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1])
else:
# For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(
self.env.action_space.shape, dtype=self.env.action_space.dtype
)
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
full_action[self.active_dims] = masked_action
return full_action
@@ -493,13 +456,9 @@ class JointMaskingActionSpace(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][
self.active_dims
]
info["action_intervention"] = info["action_intervention"][self.active_dims]
else:
info["action_intervention"] = info["action_intervention"][
:, self.active_dims
]
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
return obs, reward, terminated, truncated, info
@@ -555,9 +514,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(
low=0, high=255, shape=new_shape
)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
self.resize_size = resize_size
if self.resize_size is None:
@@ -583,9 +540,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
)
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(
f"NaN values detected in observation {k} before crop and resize"
)
logging.error(f"NaN values detected in observation {k} before crop and resize")
if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
@@ -595,9 +550,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(
f"NaN values detected in observation {k} after crop and resize"
)
logging.error(f"NaN values detected in observation {k} after crop and resize")
obs[k] = obs[k].to(device)
@@ -627,14 +580,10 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
observation = preprocess_observation(observation)
observation = {
key: observation[key].to(
self.device, non_blocking=self.device.type == "cuda"
)
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
}
observation = {
k: torch.tensor(v, device=self.device) for k, v in observation.items()
}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation
@@ -686,26 +635,16 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
play_sounds=True,
)
return
if (
self.events["pause_policy"]
and not self.events["human_intervention_step"]
):
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say(
"Starting human intervention.", play_sounds=True
)
log_say("Starting human intervention.", play_sounds=True)
return
if (
self.events["pause_policy"]
and self.events["human_intervention_step"]
):
if self.events["pause_policy"] and self.events["human_intervention_step"]:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say(
"Continuing with policy actions.", play_sounds=True
)
log_say("Continuing with policy actions.", play_sounds=True)
return
except Exception as e:
print(f"Error handling key press: {e}")
@@ -713,9 +652,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning(
"Could not import pynput. Keyboard interface will not be available."
)
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
@@ -742,9 +679,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
time.sleep(0.1) # Check more frequently if desired
# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step(
(policy_action, is_intervention)
)
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
# Override reward and termination if episode success event triggered
with self.event_lock:
@@ -807,9 +742,7 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)
def observation(
self, observation: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
@@ -844,9 +777,7 @@ class EEActionWrapper(gym.ActionWrapper):
dtype=np.float32,
)
if isinstance(self.action_space, gym.spaces.Tuple):
self.action_space = gym.spaces.Tuple(
(ee_action_space, self.action_space[1])
)
self.action_space = gym.spaces.Tuple((ee_action_space, self.action_space[1]))
else:
self.action_space = ee_action_space
@@ -858,9 +789,7 @@ class EEActionWrapper(gym.ActionWrapper):
if isinstance(action, tuple):
action, _ = action
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
current_ee_pos = self.fk_function(current_joint_pos)
if isinstance(action, torch.Tensor):
action = action.cpu().numpy()
@@ -898,9 +827,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
self.fk_function = self.kinematics.fk_gripper_tip
def observation(self, observation):
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat(
[
@@ -944,8 +871,8 @@ class GamepadControlWrapper(gym.Wrapper):
"""
super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import (
GamepadControllerHID,
GamepadController,
GamepadControllerHID,
)
# use HidApi for macos
@@ -1027,9 +954,7 @@ class GamepadControlWrapper(gym.Wrapper):
# Update episode ending state if requested
if terminate_episode:
logging.info(
f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}"
)
logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")
# Only override the action if gamepad is active
if is_intervention:
@@ -1054,9 +979,7 @@ class GamepadControlWrapper(gym.Wrapper):
logging.info("Episode ended successfully with reward 1.0")
info["is_intervention"] = is_intervention
action_intervention = (
final_action[0] if isinstance(final_action, Tuple) else final_action
)
action_intervention = final_action[0] if isinstance(final_action, Tuple) else final_action
if isinstance(action_intervention, np.ndarray):
action_intervention = torch.from_numpy(action_intervention)
info["action_intervention"] = action_intervention
@@ -1087,9 +1010,7 @@ class GamepadControlWrapper(gym.Wrapper):
class ActionScaleWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
super().__init__(env)
assert ee_action_space_params is not None, (
"TODO: method implemented for ee action space only so far"
)
assert ee_action_space_params is not None, "TODO: method implemented for ee action space only so far"
self.scale_vector = np.array(
[
[
@@ -1148,9 +1069,7 @@ def make_robot_env(
if cfg.env.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.env.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(
env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds
)
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)
env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)
@@ -1163,13 +1082,9 @@ def make_robot_env(
# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
if cfg.env.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params
)
env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
if (
cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad
@@ -1193,9 +1108,7 @@ def make_robot_env(
cfg.env.wrapper.ee_action_space_params is None
and cfg.env.wrapper.joint_masking_action_space is not None
):
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)
return env
@@ -1216,9 +1129,7 @@ def get_classifier(pretrained_path, config_path, device="mps"):
cfg = init_hydra_config(config_path)
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
@@ -1317,9 +1228,7 @@ def record_dataset(
# For teleop, get action from intervention
if policy is None:
action = {
"action": info["action_intervention"].cpu().squeeze(0).float()
}
action = {"action": info["action_intervention"].cpu().squeeze(0).float()}
# Process observation for dataset
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
@@ -1357,9 +1266,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
local_files_only = root is not None
dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
env.reset()
actions = dataset.hf_dataset.select_columns("action")
@@ -1414,9 +1321,7 @@ if __name__ == "__main__":
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument(
"--env-path", type=str, default=None, help="Path to the env yaml file"
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument(
"--env-overrides",
type=str,
@@ -1441,12 +1346,8 @@ if __name__ == "__main__":
default=None,
help="Repo ID of the episode to replay",
)
parser.add_argument(
"--dataset-root", type=str, default=None, help="Root of the dataset to replay"
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
parser.add_argument(
"--record-repo-id",
type=str,
@@ -1534,9 +1435,7 @@ if __name__ == "__main__":
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step(
(torch.from_numpy(smoothed_action), False)
)
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
sucesses.append(reward)
env.reset()
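Note (not part of this diff): the smoke-test loop above feeds exponentially smoothed random actions to the wrapped environment as (action, is_intervention) tuples. A minimal sketch of that pattern, assuming an env built by make_robot_env with a Box action space of dimension dim:

import numpy as np
import torch

def smoothed_random_rollout(env, num_steps: int = 100, dim: int = 6, alpha: float = 0.5):
    # `env` is assumed to come from make_robot_env and to accept
    # (action_tensor, is_intervention) tuples, as the wrappers above do.
    smoothed_action = np.zeros(dim, dtype=np.float32)
    for _ in range(num_steps):
        new_random_action = np.random.uniform(-1.0, 1.0, size=dim).astype(np.float32)
        # Exponential moving average keeps consecutive random actions from jumping around.
        smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
        obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
        if terminated or truncated:
            env.reset()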


@@ -23,11 +23,7 @@ def screw_axis_to_transform(S, theta):
elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (
np.eye(3) * theta
+ (1 - np.cos(theta)) * w_hat
+ (theta - np.sin(theta)) * w_hat @ w_hat
) @ S_v
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
@@ -189,9 +185,7 @@ class RobotKinematics:
# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array(
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
)
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])
# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
@@ -284,12 +278,7 @@ class RobotKinematics:
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ self.X_SoSc
@ self.X_BS
)
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS
def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
@@ -403,15 +392,12 @@ class RobotKinematics:
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac
def ik(
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
):
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
"""Inverse kinematics using gradient descent.
Args:
@@ -457,9 +443,7 @@ if __name__ == "__main__":
# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array(
[30, 45, -30, 20, 10, 0]
) # Example joint angles in degrees
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees
# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
@@ -480,13 +464,9 @@ if __name__ == "__main__":
]
# Check if distances generally increase along the chain
is_consistent = all(
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
)
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
)
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")
# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
@@ -498,9 +478,7 @@ if __name__ == "__main__":
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)
print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
)
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")
# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")


@@ -17,15 +17,8 @@
import logging
import shutil
import time
from pprint import pformat
from concurrent.futures import ThreadPoolExecutor
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from lerobot.scripts.server.utils import setup_process_handlers
from pprint import pformat
import grpc
@@ -37,6 +30,11 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn
# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.common.datasets.factory import make_dataset
@@ -55,18 +53,17 @@ from lerobot.common.utils.utils import (
set_global_random_state,
set_global_seed,
)
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)
from lerobot.scripts.server import learner_service
from lerobot.scripts.server.utils import setup_process_handlers
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
@@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"Resume=True detected, resuming previous run",
@@ -136,9 +129,7 @@ def load_training_state(
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@@ -210,22 +201,15 @@ def initialize_offline_replay_buffer(
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if (
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)
return observation_features, next_observation_features
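Note (not part of this diff): get_observation_features above skips feature caching whenever the vision encoder is trainable, because gradients must then flow through it. A minimal sketch of the frozen-encoder case it optimizes, using a hypothetical helper name:

import torch
from torch import nn

def precompute_frozen_features(encoder: nn.Module | None, observations, next_observations):
    # With a frozen encoder, features can be computed once per batch under no_grad()
    # and reused by every critic/actor update on that batch.
    if encoder is None:
        return None, None
    with torch.no_grad():
        return encoder(observations), encoder(next_observations)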
@@ -452,9 +436,7 @@ def add_actor_information_and_train(
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# Update the policy config with the grad_clip_norm value from training config if it exists
@@ -469,9 +451,7 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
log_training_info(cfg, out_dir, policy)
@@ -483,9 +463,7 @@ def add_actor_information_and_train(
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
@@ -502,12 +480,8 @@ def add_actor_information_and_train(
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = (
resume_optimization_step if resume_optimization_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
# Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning
@@ -519,9 +493,7 @@ def add_actor_information_and_train(
device = cfg.device
storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = (
cfg.actor_learner_config.policy_parameters_push_frequency
)
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps
@@ -544,9 +516,9 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)
if cfg.dataset_repo_id is not None and transition.get(
"complementary_info", {}
).get("is_intervention"):
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)
logging.debug("[LEARNER] Received transitions")
@@ -556,9 +528,7 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
logging.debug("[LEARNER] Received interactions")
@@ -579,9 +549,7 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
@@ -619,9 +587,7 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
@@ -697,23 +663,15 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(
offline_replay_buffer
)
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
)
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (
time_for_one_optimization_step + 1e-9
)
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
{
@@ -728,16 +686,12 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if save_checkpoint and (
optimization_step % save_freq == 0 or optimization_step == online_steps
):
if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"]
if interaction_message is not None
else 0
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
optimization_step,
@@ -755,9 +709,7 @@ def add_actor_information_and_train(
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
)
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"
@@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,


@@ -1,10 +1,10 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import logging
from multiprocessing import Event, Queue
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks
MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
@@ -64,9 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):
def SendInteractions(self, request_iterator, _context):
# TODO: authorize the request
logging.info(
"[LEARNER] Received request to receive interactions from the Actor"
)
logging.info("[LEARNER] Received request to receive interactions from the Actor")
receive_bytes_in_chunks(
request_iterator,


@@ -1,12 +1,12 @@
import einops
import numpy as np
import gymnasium as gym
import torch
from omegaconf import DictConfig
from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from omegaconf import DictConfig
def preprocess_maniskill_observation(
@@ -63,9 +63,7 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -84,9 +82,7 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
def action(self, action):
action, telop = action
@@ -100,9 +96,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
def step(self, action):
if isinstance(action, tuple):
@@ -153,9 +147,7 @@ def make_maniskill(
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
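Note (not part of this diff): ManiSkillActionWrapper and ManiSkillMultiplyActionWrapper above expose a Tuple action space whose second element is a Discrete(2) teleoperation flag. A minimal sketch of that convention, with a hypothetical action dimension:

import gymnasium as gym
import numpy as np

agent_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(7,), dtype=np.float32)
action_space = gym.spaces.Tuple(spaces=(agent_space, gym.spaces.Discrete(2)))

action, telop = action_space.sample()  # unpacking mirrors ManiSkillActionWrapper.action()
assert action_space.contains((action, telop))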
@@ -166,12 +158,11 @@ def make_maniskill(
if __name__ == "__main__":
import argparse
import hydra
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
args = parser.parse_args()
# Initialize config


@@ -15,12 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.scripts.server import hilserl_pb2
import logging
import io
from multiprocessing import Queue, Event
import logging
from multiprocessing import Event, Queue
from typing import Any
from lerobot.scripts.server import hilserl_pb2
CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB
@@ -31,9 +32,7 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
return result
def send_bytes_in_chunks(
buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)
@@ -56,16 +55,12 @@ def send_bytes_in_chunks(
yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")
logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")
def receive_bytes_in_chunks(
iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
bytes_buffer = io.BytesIO()
step = 0
@@ -89,9 +84,7 @@ def receive_bytes_in_chunks(
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(
f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")
queue.put(bytes_buffer.getvalue())
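Note (not part of this diff): send_bytes_in_chunks and receive_bytes_in_chunks above stream one large byte payload as fixed-size gRPC messages and reassemble it on the other side. A minimal transport-free sketch of the chunking idea:

import io

CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB, matching the constant above

def iter_chunks(payload: bytes, chunk_size: int = CHUNK_SIZE):
    # Yield the payload in fixed-size pieces so no single message exceeds the size limit.
    buffer = io.BytesIO(payload)
    while chunk := buffer.read(chunk_size):
        yield chunk

payload = b"x" * (5 * 1024 * 1024)
assert b"".join(iter_chunks(payload)) == payload  # the receiver's BytesIO performs this join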


@@ -18,9 +18,10 @@
import logging
import signal
import sys
from torch.multiprocessing import Queue
from queue import Empty
from torch.multiprocessing import Queue
shutdown_event_counter = 0


@@ -223,18 +223,12 @@ def train(cfg: TrainPipelineConfig):
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler
)
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
@@ -335,9 +329,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Eval policy at step {step}")
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,


@@ -52,19 +52,13 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
return model
def create_balanced_sampler(dataset, cfg):
# Get underlying dataset if using Subset
original_dataset = (
dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
)
original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
# Get indices if using Subset (for slicing)
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
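Note (not part of this diff): create_balanced_sampler above weights each sample by the inverse frequency of its class so that minority-class frames are drawn as often as majority-class ones. A minimal sketch with toy labels:

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 1])  # imbalanced toy label set
counts = torch.bincount(labels)         # samples per class: tensor([4, 1])
class_weights = 1.0 / counts.float()    # rarer classes get larger weights
sample_weights = class_weights[labels]  # one weight per sample
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)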
@@ -83,9 +77,7 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -94,9 +86,7 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")
def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@@ -110,11 +100,7 @@ def train_epoch(
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
@@ -159,9 +145,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@@ -174,9 +158,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
else:
outputs = model(images)
@@ -194,24 +176,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):
# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@@ -286,9 +260,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
_ = model(x)
inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
inference_times = np.array(inference_times)
@@ -314,9 +286,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std
def train(
cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
) -> None:
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
if out_dir is None:
raise NotImplementedError()
if job_name is None:
@@ -372,9 +342,7 @@ def train(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@@ -387,9 +355,7 @@ def train(
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@@ -408,11 +374,7 @@ def train(
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
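Note (not part of this diff): the training and validation loops above enter torch.autocast only when AMP is supported and enabled, falling back to nullcontext otherwise, while GradScaler(enabled=...) keeps the optimizer path identical in both modes. A minimal sketch of one such step, with hypothetical model/criterion arguments:

from contextlib import nullcontext

import torch
from torch.cuda.amp import GradScaler

def training_step(model, criterion, optimizer, scaler: GradScaler, images, labels, device, use_amp: bool):
    amp_ctx = torch.autocast(device_type=device.type) if use_amp else nullcontext()
    with amp_ctx:
        outputs = model(images)
        loss = criterion(outputs, labels)
    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()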


@@ -52,9 +52,7 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -106,9 +104,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
# Gather pixels
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
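Note (not part of this diff): random_crop_vectorized above crops every image in the batch at an independent random location using a single advanced-indexing gather instead of a Python loop. A self-contained sketch of the same idea:

import torch

def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    # images: (B, C, H, W); one independent random crop per image, no Python loop.
    B, C, H, W = images.shape
    crop_h, crop_w = output_size
    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
    rows = tops.view(B, 1, 1) + torch.arange(crop_h, device=images.device).view(1, crop_h, 1)
    cols = lefts.view(B, 1, 1) + torch.arange(crop_w, device=images.device).view(1, 1, crop_w)
    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    return cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)

batch = torch.rand(8, 3, 128, 128)
assert random_crop_vectorized(batch, (112, 112)).shape == (8, 3, 112, 112)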
@@ -198,12 +194,8 @@ class ReplayBuffer:
"""
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
# a replay buffer than from a lerobot dataset.
replay_buffer = cls(
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
)
list_transition = cls._lerobotdataset_to_transitions(
dataset=lerobot_dataset, state_keys=state_keys
)
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
# Fill the replay buffer with the lerobot dataset transitions
for data in list_transition:
replay_buffer.add(
@@ -248,9 +240,7 @@ class ReplayBuffer:
# If not provided, you can either raise an error or define a default:
if state_keys is None:
raise ValueError(
"You must provide a list of keys in `state_keys` that define your 'state'."
)
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
transitions: list[Transition] = []
num_frames = len(dataset)
@@ -304,40 +294,36 @@ class ReplayBuffer:
# -- Build batched states --
batch_state = {}
for key in self.state_keys:
batch_state[key] = torch.cat(
[t["state"][key] for t in list_of_transitions], dim=0
).to(self.device)
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_state[key] = self.image_augmentation_function(batch_state[key])
# -- Build batched actions --
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
self.device
)
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
# -- Build batched rewards --
batch_rewards = torch.tensor(
[t["reward"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# -- Build batched next states --
batch_next_state = {}
for key in self.state_keys:
batch_next_state[key] = torch.cat(
[t["next_state"][key] for t in list_of_transitions], dim=0
).to(self.device)
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
self.device
)
if key.startswith("observation.image") and self.use_drq:
batch_next_state[key] = self.image_augmentation_function(
batch_next_state[key]
)
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
# -- Build batched dones --
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor(
[t["done"] for t in list_of_transitions], dtype=torch.float32
).to(self.device)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
self.device
)
# Return a BatchTransition typed dict
return BatchTransition(
@@ -427,9 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
device=device,
)
assert isinstance(policy, nn.Module)
@@ -438,9 +422,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# TODO: Handle resume
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@@ -481,16 +463,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
@@ -500,20 +478,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
replay_buffer.add(
state=obs,
@@ -587,9 +559,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
training_infos["loss_actor"] = loss_actor.item()
loss_temperature = policy.compute_loss_temperature(
observations=observations
)
loss_temperature = policy.compute_loss_temperature(observations=observations)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -611,9 +581,7 @@ def train_cli(cfg: dict):
)
def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()


@@ -94,12 +94,8 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, (
f"expect channel first images, but instead {chw_float32_torch.shape}"
)
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy
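Note (not part of this diff): a quick usage sketch for the helper above, with hypothetical image dimensions:

import numpy as np
import torch

chw = torch.rand(3, 128, 96, dtype=torch.float32)  # channel-first float image in [0, 1]
hwc = to_hwc_uint8_numpy(chw)
assert hwc.shape == (128, 96, 3) and hwc.dtype == np.uint8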


@@ -142,12 +142,8 @@ def run_server(
)
)
@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
@@ -158,9 +154,7 @@ def run_server(
400,
)
dataset_version = (
str(dataset.meta._version)
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@@ -168,9 +162,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns, ignored_columns = get_episode_data(
dataset, episode_id
)
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
@@ -183,8 +175,7 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{
@@ -197,9 +188,7 @@ def run_server(
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@@ -219,24 +208,16 @@ def run_server(
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
return render_template(
@@ -263,11 +244,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
selected_columns = [
col
for col, ft in dataset.features.items()
if ft["dtype"] in ["float32", "int32"]
]
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")
ignored_columns = []
@@ -288,10 +265,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0]
)
if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@@ -314,12 +288,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id
url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
@@ -352,9 +323,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@@ -365,9 +334,7 @@ def get_episode_language_instruction(
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
@@ -403,9 +370,7 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
output_dir.mkdir(parents=True, exist_ok=True)
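Note (not part of this diff): show_episode above relies on Flask's typed URL converters, so the namespace, dataset name, and episode id are parsed straight out of the path. A minimal standalone sketch of that routing pattern:

from flask import Flask

app = Flask(__name__)

@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id):
    # Flask parses the typed segments and passes them as keyword arguments.
    return f"{dataset_namespace}/{dataset_name}, episode {episode_id}"

# app.run(host="127.0.0.1", port=9090) would serve e.g. /lerobot/pusht/episode_0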


@@ -47,9 +47,7 @@ OUTPUT_DIR = Path("outputs/image_transforms")
to_pil = ToPILImage()
def save_all_transforms(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
@@ -62,9 +60,7 @@ def save_all_transforms(
print(f" {output_dir_all}")
def save_each_transform(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
if not cfg.enable:
logging.warning(
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
@@ -93,15 +89,9 @@ def save_each_transform(
tf_cfg_kwgs_max[key] = [max_, max_]
tf_cfg_kwgs_avg[key] = [avg, avg]
tf_min = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})
)
tf_max = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})
)
tf_avg = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})
)
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))
tf_frame_min = tf_min(original_frame)
tf_frame_max = tf_max(original_frame)
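Note (not part of this diff): the min/max/avg variants above are built with dataclasses.replace, which copies the transform config while overriding only its kwargs. A minimal sketch with a hypothetical config class:

from dataclasses import dataclass, field, replace

@dataclass
class TransformConfig:  # hypothetical stand-in for the real transform config
    type: str = "ColorJitter"
    kwargs: dict = field(default_factory=dict)

base = TransformConfig(kwargs={"brightness": (0.8, 1.2)})
# Pin the sampled range to a single value to visualize one extreme of the transform.
tf_cfg_min = replace(base, **{"kwargs": {"brightness": (0.8, 0.8)}})
tf_cfg_max = replace(base, **{"kwargs": {"brightness": (1.2, 1.2)}})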
@@ -115,9 +105,7 @@ def save_each_transform(
@draccus.wrap()
def visualize_image_transforms(
cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5
):
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,