[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
committed by: Michel Aractingi
parent: 2abbd60a0d
commit: 0ea27704f6
@@ -67,8 +67,8 @@ def get_motor_bus_cls(brand: str) -> tuple:

 def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
-    motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = (
-        get_motor_bus_cls(brand)
+    motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls(
+        brand
     )

     # Check if the provided model exists in the model_baud_rate_table
@@ -82,9 +82,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
     motor_index_arbitrary = motor_idx_des  # Use the motor ID passed via argument
     motor_model = model  # Use the motor model passed via argument

-    config = motor_bus_config_cls(
-        port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}
-    )
+    config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)})

     # Initialize the MotorBus with the correct port and motor configurations
     motor_bus = motor_bus_cls(config=config)
@@ -120,26 +118,20 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
             break

     if motor_index == -1:
-        raise ValueError(
-            "No motors detected. Please ensure you have one motor connected."
-        )
+        raise ValueError("No motors detected. Please ensure you have one motor connected.")

     print(f"Motor index found at: {motor_index}")

     if brand == "feetech":
         # Allows ID and BAUDRATE to be written in memory
-        motor_bus.write_with_motor_ids(
-            motor_bus.motor_models, motor_index, "Lock", 0
-        )
+        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)

     if baudrate != baudrate_des:
         print(f"Setting its baudrate to {baudrate_des}")
         baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)

         # The write can fail, so we allow retries
-        motor_bus.write_with_motor_ids(
-            motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
-        )
+        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
         time.sleep(0.5)
         motor_bus.set_bus_baudrate(baudrate_des)
         present_baudrate_idx = motor_bus.read_with_motor_ids(
@@ -151,16 +143,10 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):

     print(f"Setting its index to desired index {motor_idx_des}")
     if brand == "feetech":
-        motor_bus.write_with_motor_ids(
-            motor_bus.motor_models, motor_index, "Lock", 0
-        )
-        motor_bus.write_with_motor_ids(
-            motor_bus.motor_models, motor_index, "ID", motor_idx_des
-        )
+        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
+        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)

-    present_idx = motor_bus.read_with_motor_ids(
-        motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
-    )
+    present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
     if present_idx != motor_idx_des:
         raise OSError("Failed to write index.")
@@ -194,12 +180,8 @@ if __name__ == "__main__":
         required=True,
         help="Motors bus port (e.g. dynamixel,feetech)",
     )
-    parser.add_argument(
-        "--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
-    )
-    parser.add_argument(
-        "--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
-    )
+    parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
+    parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
    parser.add_argument(
        "--ID",
        type=int,
@@ -255,8 +255,7 @@ def record(
         if len(robot.cameras) > 0:
             dataset.start_image_writer(
                 num_processes=cfg.num_image_writer_processes,
-                num_threads=cfg.num_image_writer_threads_per_camera
-                * len(robot.cameras),
+                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
             )
         sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
     else:
@@ -269,19 +268,14 @@ def record(
             robot=robot,
             use_videos=cfg.video,
             image_writer_processes=cfg.num_image_writer_processes,
-            image_writer_threads=cfg.num_image_writer_threads_per_camera
-            * len(robot.cameras),
+            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
         )

     # Load pretrained policy
-    policy = (
-        None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
-    )
+    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

     # Load pretrained policy
-    policy = (
-        None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
-    )
+    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

     if not robot.is_connected:
         robot.connect()
@@ -174,10 +174,7 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
         leader_pos = robot.leader_arms.main.read("Present_Position")
         action = process_action_fn(leader_pos)
         env.step(np.expand_dims(action, 0))
-        if (
-            teleop_time_s is not None
-            and time.perf_counter() - start_teleop_t > teleop_time_s
-        ):
+        if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
             print("Teleoperation processes finished.")
             break
@@ -209,27 +206,19 @@ def record(
     # Load pretrained policy

     extra_features = (
-        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
-        if assign_rewards
-        else None
+        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
     )

     policy = None
     if pretrained_policy_name_or_path is not None:
-        policy, policy_fps, device, use_amp = init_policy(
-            pretrained_policy_name_or_path, policy_overrides
-        )
+        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)

         if fps is None:
             fps = policy_fps
-            logging.warning(
-                f"No fps provided, so using the fps from policy config ({policy_fps})."
-            )
+            logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")

     if policy is None and process_action_from_leader is None:
-        raise ValueError(
-            "Either policy or process_action_fn has to be set to enable control in sim."
-        )
+        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")

     # initialize listener before sim env
     listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
@@ -380,9 +369,7 @@ def record(
             if events["stop_recording"] or recorded_episodes >= num_episodes:
                 break
             else:
-                logging.info(
-                    "Waiting for a few seconds before starting next episode recording..."
-                )
+                logging.info("Waiting for a few seconds before starting next episode recording...")
                 busy_wait(3)

     log_say("Stop recording", play_sounds, blocking=True)
@@ -481,9 +468,7 @@ if __name__ == "__main__":
         required=True,
         help="A description of the task preformed during recording that can be used as a language instruction.",
     )
-    parser_record.add_argument(
-        "--num-episodes", type=int, default=50, help="Number of episodes to record."
-    )
+    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
     parser_record.add_argument(
         "--run-compute-stats",
         type=int,
@@ -561,9 +546,7 @@ if __name__ == "__main__":
         default="lerobot/test",
         help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
     )
-    parser_replay.add_argument(
-        "--episode", type=int, default=0, help="Index of the episodes to replay."
-    )
+    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")

     args = parser.parse_args()
@@ -59,11 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"

 torch_version = torch.__version__ if HAS_TORCH else "N/A"
 torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
-cuda_version = (
-    torch._C._cuda_getCompiledVersion()
-    if HAS_TORCH and torch.version.cuda is not None
-    else "N/A"
-)
+cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"


 # TODO(aliberts): refactor into an actual command `lerobot env`
@@ -81,9 +77,7 @@ def display_sys_info() -> dict:
         "Using GPU in script?": "<fill in>",
         # "Using distributed or parallel set-up in script?": "<fill in>",
     }
-    print(
-        "\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
-    )
+    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
     print(format_dict(info))
     return info
@@ -152,8 +152,7 @@ def rollout(
         all_observations.append(deepcopy(observation))

         observation = {
-            key: observation[key].to(device, non_blocking=device.type == "cuda")
-            for key in observation
+            key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation
         }

         # Infer "task" from attributes of environments.
@@ -175,10 +174,7 @@ def rollout(
         # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
         # available of none of the envs finished.
         if "final_info" in info:
-            successes = [
-                info["is_success"] if info is not None else False
-                for info in info["final_info"]
-            ]
+            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
         else:
             successes = [False] * env.num_envs
@@ -192,13 +188,9 @@ def rollout(

         step += 1
         running_success_rate = (
-            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
-            .numpy()
-            .mean()
-        )
-        progbar.set_postfix(
-            {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
+            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
         )
+        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
         progbar.update()

     # Track the final observation.
@@ -216,9 +208,7 @@ def rollout(
     if return_observations:
         stacked_observations = {}
         for key in all_observations[0]:
-            stacked_observations[key] = torch.stack(
-                [obs[key] for obs in all_observations], dim=1
-            )
+            stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
         ret["observation"] = stacked_observations

     if hasattr(policy, "use_original_modules"):
@@ -280,9 +270,7 @@ def eval_policy(
                 return
             n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
             if isinstance(env, gym.vector.SyncVectorEnv):
-                ep_frames.append(
-                    np.stack([env.envs[i].render() for i in range(n_to_render_now)])
-                )  # noqa: B023
+                ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)]))  # noqa: B023
             elif isinstance(env, gym.vector.AsyncVectorEnv):
                 # Here we must render all frames and discard any we don't need.
                 ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
@@ -294,9 +282,7 @@ def eval_policy(
     episode_data: dict | None = None

     # we dont want progress bar when we use slurm, since it clutters the logs
-    progbar = trange(
-        n_batches, desc="Stepping through eval batches", disable=inside_slurm()
-    )
+    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
     for batch_ix in progbar:
         # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
         # step.
@@ -326,22 +312,13 @@ def eval_policy(

         # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
         # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
-        mask = (
-            torch.arange(n_steps)
-            <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
-        ).int()
+        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
         # Extend metrics.
-        batch_sum_rewards = einops.reduce(
-            (rollout_data["reward"] * mask), "b n -> b", "sum"
-        )
+        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
         sum_rewards.extend(batch_sum_rewards.tolist())
-        batch_max_rewards = einops.reduce(
-            (rollout_data["reward"] * mask), "b n -> b", "max"
-        )
+        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
         max_rewards.extend(batch_max_rewards.tolist())
-        batch_successes = einops.reduce(
-            (rollout_data["success"] * mask), "b n -> b", "any"
-        )
+        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
         all_successes.extend(batch_successes.tolist())
         if seeds:
             all_seeds.extend(seeds)
@@ -354,27 +331,17 @@ def eval_policy(
                 rollout_data,
                 done_indices,
                 start_episode_index=batch_ix * env.num_envs,
-                start_data_index=(
-                    0
-                    if episode_data is None
-                    else (episode_data["index"][-1].item() + 1)
-                ),
+                start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
                 fps=env.unwrapped.metadata["render_fps"],
             )
             if episode_data is None:
                 episode_data = this_episode_data
             else:
                 # Some sanity checks to make sure we are correctly compiling the data.
-                assert (
-                    episode_data["episode_index"][-1] + 1
-                    == this_episode_data["episode_index"][0]
-                )
+                assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
                 assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
                 # Concatenate the episode data.
-                episode_data = {
-                    k: torch.cat([episode_data[k], this_episode_data[k]])
-                    for k in episode_data
-                }
+                episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}

         # Maybe render video for visualization.
         if max_episodes_rendered > 0 and len(ep_frames) > 0:
@@ -392,9 +359,7 @@ def eval_policy(
                     target=write_video,
                     args=(
                         str(video_path),
-                        stacked_frames[
-                            : done_index + 1
-                        ],  # + 1 to capture the last observation
+                        stacked_frames[: done_index + 1],  # + 1 to capture the last observation
                         env.unwrapped.metadata["render_fps"],
                     ),
                 )
@@ -403,9 +368,7 @@ def eval_policy(
                 n_episodes_rendered += 1

         progbar.set_postfix(
-            {
-                "running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
-            }
+            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
         )

     # Wait till all video rendering threads are done.
@@ -473,16 +436,12 @@ def _compile_episode_data(
         # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
         ep_dict = {
             "action": rollout_data["action"][ep_ix, : num_frames - 1],
-            "episode_index": torch.tensor(
-                [start_episode_index + ep_ix] * (num_frames - 1)
-            ),
+            "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
             "frame_index": torch.arange(0, num_frames - 1, 1),
             "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
             "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
             "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
-            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
-                torch.float32
-            ),
+            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
         }

         # For the last observation frame, all other keys will just be copy padded.
@@ -498,9 +457,7 @@ def _compile_episode_data(
     for key in ep_dicts[0]:
         data_dict[key] = torch.cat([x[key] for x in ep_dicts])

-    data_dict["index"] = torch.arange(
-        start_data_index, start_data_index + total_frames, 1
-    )
+    data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)

     return data_dict
@@ -516,14 +473,10 @@ def eval_main(cfg: EvalPipelineConfig):
     torch.backends.cuda.matmul.allow_tf32 = True
     set_seed(cfg.seed)

-    logging.info(
-        colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
-    )
+    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

     logging.info("Making environment.")
-    env = make_env(
-        cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs
-    )
+    env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

     logging.info("Making policy.")
@@ -535,9 +488,7 @@ def eval_main(cfg: EvalPipelineConfig):

     with (
         torch.no_grad(),
-        torch.autocast(device_type=device.type)
-        if cfg.policy.use_amp
-        else nullcontext(),
+        torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
     ):
         info = eval_policy(
             env,
@@ -74,9 +74,7 @@ def get_classifier(pretrained_path, config_path):
     cfg = init_hydra_config(config_path)

     classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
-    classifier_config.num_cameras = len(
-        cfg.training.image_keys
-    )  # TODO automate these paths
+    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
     model = Classifier(classifier_config)
     model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
     model = model.to("mps")
@@ -161,17 +159,11 @@ def rollout(
         images = []
         for key in image_keys:
             if display_cameras:
-                cv2.imshow(
-                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
-                )
+                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                 cv2.waitKey(1)
             images.append(observation[key].to("mps"))

-        reward = (
-            reward_classifier.predict_reward(images)
-            if reward_classifier is not None
-            else 0.0
-        )
+        reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
         all_rewards.append(reward)

         # print("REWARD : ", reward)
@@ -235,9 +227,7 @@ def eval_policy(

     start_eval = time.perf_counter()
     progbar = trange(n_episodes, desc="Evaluating policy on real robot")
-    reward_classifier = get_classifier(
-        reward_classifier_pretrained_path, reward_classifier_config_file
-    )
+    reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)

     for _ in progbar:
         rollout_data = rollout(
@@ -313,9 +303,7 @@ def init_keyboard_listener():
                 print("Right arrow key pressed. Exiting loop...")
                 events["exit_early"] = True
             elif key == keyboard.Key.left:
-                print(
-                    "Left arrow key pressed. Exiting loop and rerecord the last episode..."
-                )
+                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                 events["rerecord_episode"] = True
                 events["exit_early"] = True
             elif key == keyboard.Key.space:
@@ -380,9 +368,7 @@ if __name__ == "__main__":
             "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
         ),
     )
-    parser.add_argument(
-        "--revision", help="Optionally provide the Hugging Face Hub revision ID."
-    )
+    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
     parser.add_argument(
         "--out-dir",
         help=(
@@ -45,13 +45,9 @@ def find_port():
         print(f"The port of this MotorsBus is '{port}'")
         print("Reconnect the USB cable.")
     elif len(ports_diff) == 0:
-        raise OSError(
-            f"Could not detect the port. No difference was found ({ports_diff})."
-        )
+        raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
     else:
-        raise OSError(
-            f"Could not detect the port. More than one port was found ({ports_diff})."
-        )
+        raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")


 if __name__ == "__main__":
@@ -14,18 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from statistics import mean, quantiles
+import time
 from functools import lru_cache
-from lerobot.scripts.server.utils import setup_process_handlers
+from queue import Empty
+from statistics import mean, quantiles

 # from lerobot.scripts.eval import eval_policy

 import grpc
 import hydra
 import torch
 from omegaconf import DictConfig
 from torch import nn
-import time
 from torch.multiprocessing import Event, Queue
@@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.sac.modeling_sac import SACPolicy
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import (
+    TimerManager,
     get_safe_torch_device,
     init_logging,
     set_global_seed,
 )
-from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
+from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
 from lerobot.scripts.server.buffer import (
     Transition,
+    bytes_to_state_dict,
     move_state_dict_to_device,
     move_transition_to_device,
     python_object_to_bytes,
     transitions_to_bytes,
-    bytes_to_state_dict,
 )
+from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
 from lerobot.scripts.server.network_utils import (
     receive_bytes_in_chunks,
     send_bytes_in_chunks,
 )
-from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
-from lerobot.scripts.server import learner_service
-from lerobot.common.robot_devices.utils import busy_wait
-
-from torch.multiprocessing import Queue, Event
-from queue import Empty
-
-from lerobot.common.utils.utils import init_logging
-
-from lerobot.scripts.server.utils import get_last_item_from_queue
+from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers

 ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -102,9 +96,7 @@ def receive_policy(
     logging.info("[ACTOR] Received policy loop stopped")


-def transitions_stream(
-    shutdown_event: Event, transitions_queue: Queue
-) -> hilserl_pb2.Empty:
+def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
     while not shutdown_event.is_set():
         try:
             message = transitions_queue.get(block=True, timeout=5)
@@ -169,9 +161,7 @@ def send_transitions(
     )

     try:
-        learner_client.SendTransitions(
-            transitions_stream(shutdown_event, transitions_queue)
-        )
+        learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
@@ -211,9 +201,7 @@ def send_interactions(
     )

     try:
-        learner_client.SendInteractions(
-            interactions_stream(shutdown_event, interactions_queue)
-        )
+        learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
     except grpc.RpcError as e:
         logging.error(f"[ACTOR] gRPC error: {e}")
@@ -301,9 +289,7 @@ def act_with_policy(

     logging.info("make_env online")

-    online_env = make_robot_env(
-        robot=robot, reward_classifier=reward_classifier, cfg=cfg
-    )
+    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)

     set_global_seed(cfg.seed)
     device = get_safe_torch_device(cfg.device, log=True)
@@ -355,13 +341,9 @@ def act_with_policy(
                 action = policy.select_action(batch=obs)
                 policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)

-                log_policy_frequency_issue(
-                    policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
-                )
+                log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

-            next_obs, reward, done, truncated, info = online_env.step(
-                action.squeeze(dim=0).cpu().numpy()
-            )
+            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
         else:
             # TODO (azouitine): Make a custom space for torch tensor
             action = online_env.action_space.sample()
@@ -369,9 +351,7 @@ def act_with_policy(

             # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
             action = (
-                torch.from_numpy(action[0])
-                .to(device, non_blocking=device.type == "cuda")
-                .unsqueeze(dim=0)
+                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
             )

             sum_reward_episode += float(reward)
@@ -391,9 +371,7 @@ def act_with_policy(
         # Check for NaN values in observations
         for key, tensor in obs.items():
             if torch.isnan(tensor).any():
-                logging.error(
-                    f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
-                )
+                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

         list_transition_to_send_to_learner.append(
             Transition(
@@ -413,13 +391,9 @@ def act_with_policy(
         # Because we are using a single environment we can index at zero
         if done or truncated:
             # TODO: Handle logging for episode information
-            logging.info(
-                f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
-            )
+            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")

-            update_policy_parameters(
-                policy=policy.actor, parameters_queue=parameters_queue, device=device
-            )
+            update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)

             if len(list_transition_to_send_to_learner) > 0:
                 push_transitions_to_transport_queue(
@@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
     return stats


-def log_policy_frequency_issue(
-    policy_fps: float, cfg: DictConfig, interaction_step: int
-):
+def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
     if policy_fps < cfg.fps:
         logging.warning(
             f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
@@ -14,16 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import io
+import os
+import pickle
 from typing import Any, Callable, Optional, Sequence, TypedDict

-import io
 import torch
 import torch.nn.functional as F  # noqa: N812
 from tqdm import tqdm

 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-import os
-import pickle


 class Transition(TypedDict):
@@ -45,38 +45,27 @@ class BatchTransition(TypedDict):
     truncated: torch.Tensor


-def move_transition_to_device(
-    transition: Transition, device: str = "cpu"
-) -> Transition:
+def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
     # Move state tensors to CPU
     device = torch.device(device)
     transition["state"] = {
-        key: val.to(device, non_blocking=device.type == "cuda")
-        for key, val in transition["state"].items()
+        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
     }

     # Move action to CPU
-    transition["action"] = transition["action"].to(
-        device, non_blocking=device.type == "cuda"
-    )
+    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")

     # No need to move reward or done, as they are float and bool

     # No need to move reward or done, as they are float and bool
     if isinstance(transition["reward"], torch.Tensor):
-        transition["reward"] = transition["reward"].to(
-            device=device, non_blocking=device.type == "cuda"
-        )
+        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")

     if isinstance(transition["done"], torch.Tensor):
-        transition["done"] = transition["done"].to(
-            device, non_blocking=device.type == "cuda"
-        )
+        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")

     if isinstance(transition["truncated"], torch.Tensor):
-        transition["truncated"] = transition["truncated"].to(
-            device, non_blocking=device.type == "cuda"
-        )
+        transition["truncated"] = transition["truncated"].to(device, non_blocking=device.type == "cuda")

     # Move next_state tensors to CPU
     transition["next_state"] = {
@@ -100,10 +89,7 @@ def move_state_dict_to_device(state_dict, device="cpu"):
     if isinstance(state_dict, torch.Tensor):
         return state_dict.to(device)
     elif isinstance(state_dict, dict):
-        return {
-            k: move_state_dict_to_device(v, device=device)
-            for k, v in state_dict.items()
-        }
+        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
     elif isinstance(state_dict, list):
         return [move_state_dict_to_device(v, device=device) for v in state_dict]
     elif isinstance(state_dict, tuple):
@@ -174,9 +160,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
     images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

     # Gather pixels
-    cropped_hwcn = images_hwcn[
-        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
-    ]
+    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
     # cropped_hwcn => (B, crop_h, crop_w, C)

     cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
@@ -223,9 +207,7 @@ class ReplayBuffer:
         self.optimize_memory = optimize_memory

         # Track episode boundaries for memory optimization
-        self.episode_ends = torch.zeros(
-            capacity, dtype=torch.bool, device=storage_device
-        )
+        self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device)

         # If no state_keys provided, default to an empty list
         self.state_keys = state_keys if state_keys is not None else []
@@ -246,9 +228,7 @@ class ReplayBuffer:
             key: torch.empty((self.capacity, *shape), device=self.storage_device)
             for key, shape in state_shapes.items()
         }
-        self.actions = torch.empty(
-            (self.capacity, *action_shape), device=self.storage_device
-        )
+        self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device)
         self.rewards = torch.empty((self.capacity,), device=self.storage_device)

         if not self.optimize_memory:
@@ -262,12 +242,8 @@ class ReplayBuffer:
             # Just create a reference to states for consistent API
             self.next_states = self.states  # Just a reference for API consistency

-        self.dones = torch.empty(
-            (self.capacity,), dtype=torch.bool, device=self.storage_device
-        )
-        self.truncateds = torch.empty(
-            (self.capacity,), dtype=torch.bool, device=self.storage_device
-        )
+        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
         self.initialized = True

     def __len__(self):
@@ -294,9 +270,7 @@ class ReplayBuffer:

             if not self.optimize_memory:
                 # Only store next_states if not optimizing memory
-                self.next_states[key][self.position].copy_(
-                    next_state[key].squeeze(dim=0)
-                )
+                self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0))

         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
@@ -309,23 +283,15 @@ class ReplayBuffer:
     def sample(self, batch_size: int) -> BatchTransition:
         """Sample a random batch of transitions and collate them into batched tensors."""
         if not self.initialized:
-            raise RuntimeError(
-                "Cannot sample from an empty buffer. Add transitions first."
-            )
+            raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.")

         batch_size = min(batch_size, self.size)

         # Random indices for sampling - create on the same device as storage
-        idx = torch.randint(
-            low=0, high=self.size, size=(batch_size,), device=self.storage_device
-        )
+        idx = torch.randint(low=0, high=self.size, size=(batch_size,), device=self.storage_device)

         # Identify image keys that need augmentation
-        image_keys = (
-            [k for k in self.states if k.startswith("observation.image")]
-            if self.use_drq
-            else []
-        )
+        image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else []

         # Create batched state and next_state
         batch_state = {}
@@ -358,13 +324,9 @@ class ReplayBuffer:
             # Split the augmented images back to their sources
             for i, key in enumerate(image_keys):
                 # State images are at even indices (0, 2, 4...)
-                batch_state[key] = augmented_images[
-                    i * 2 * batch_size : (i * 2 + 1) * batch_size
-                ]
+                batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size]
                 # Next state images are at odd indices (1, 3, 5...)
-                batch_next_state[key] = augmented_images[
-                    (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size
-                ]
+                batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size]

         # Sample other tensors
         batch_actions = self.actions[idx].to(self.device)
@@ -434,16 +396,12 @@ class ReplayBuffer:
         )

         # Convert dataset to transitions
-        list_transition = cls._lerobotdataset_to_transitions(
-            dataset=lerobot_dataset, state_keys=state_keys
-        )
+        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)

         # Initialize the buffer with the first transition to set up storage tensors
         if list_transition:
             first_transition = list_transition[0]
-            first_state = {
-                k: v.to(device) for k, v in first_transition["state"].items()
-            }
+            first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
             first_action = first_transition["action"].to(device)

             # Apply action mask/delta if needed
@@ -541,9 +499,7 @@ class ReplayBuffer:

         # Convert transitions into episodes and frames
         episode_index = 0
-        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
-            episode_index=episode_index
-        )
+        lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)

         frame_idx_in_episode = 0
         for idx in range(self.size):
@@ -557,12 +513,8 @@ class ReplayBuffer:

             # Fill action, reward, done
             frame_dict["action"] = self.actions[actual_idx].cpu()
-            frame_dict["next.reward"] = torch.tensor(
-                [self.rewards[actual_idx]], dtype=torch.float32
-            ).cpu()
-            frame_dict["next.done"] = torch.tensor(
-                [self.dones[actual_idx]], dtype=torch.bool
-            ).cpu()
+            frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
+            frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()

             # Add to the dataset's buffer
             lerobot_dataset.add_frame(frame_dict)
@@ -619,9 +571,7 @@ class ReplayBuffer:
             A list of Transition dictionaries with the same length as `dataset`.
         """
         if state_keys is None:
-            raise ValueError(
-                "State keys must be provided when converting LeRobotDataset to Transitions."
-            )
+            raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

         transitions = []
         num_frames = len(dataset)
@@ -632,9 +582,7 @@ class ReplayBuffer:

         # If not, we need to infer it from episode boundaries
         if not has_done_key:
-            print(
-                "'next.done' key not found in dataset. Inferring from episode boundaries..."
-            )
+            print("'next.done' key not found in dataset. Inferring from episode boundaries...")

         for i in tqdm(range(num_frames)):
             current_sample = dataset[i]
@@ -886,8 +834,7 @@ if __name__ == "__main__":
         # We need to be careful because we don't know the original index
        # So we check if the increment is roughly 0.01
        next_state_check = (
-            abs(next_state_sig - state_sig - 0.01) < 1e-4
-            or abs(next_state_sig - state_sig) < 1e-4
+            abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
        )

        # Count correct relationships
@@ -901,17 +848,11 @@ if __name__ == "__main__":
             total_checks += 3

     alignment_accuracy = 100.0 * correct_relationships / total_checks
-    print(
-        f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%"
-    )
+    print(f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%")
     if alignment_accuracy > 99.0:
-        print(
-            "✅ All relationships verified! Buffer maintains correct temporal relationships."
-        )
+        print("✅ All relationships verified! Buffer maintains correct temporal relationships.")
     else:
-        print(
-            "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues."
-        )
+        print("⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues.")

     # Print some debug information about failures
     print("\nDebug information for failed checks:")
@@ -973,18 +914,14 @@ if __name__ == "__main__":

     # Verify consistency before and after conversion
     original_states = batch["state"]["observation.image"].mean().item()
-    reconverted_states = (
-        reconverted_batch["state"]["observation.image"].mean().item()
-    )
+    reconverted_states = reconverted_batch["state"]["observation.image"].mean().item()
     print(f"Original buffer state mean: {original_states:.4f}")
     print(f"Reconverted buffer state mean: {reconverted_states:.4f}")

     if abs(original_states - reconverted_states) < 1.0:
         print("Values are reasonably similar - conversion works as expected")
     else:
-        print(
-            "WARNING: Significant difference between original and reconverted values"
-        )
+        print("WARNING: Significant difference between original and reconverted values")

     print("\nAll previous tests completed!")
@@ -1093,15 +1030,11 @@ if __name__ == "__main__":
     all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device)

     # Get state tensors
-    batch_state = {
-        "value": test_buffer.states["value"][all_indices].to(test_buffer.device)
-    }
+    batch_state = {"value": test_buffer.states["value"][all_indices].to(test_buffer.device)}

     # Get next_state using memory-optimized approach (simply index+1)
     next_indices = (all_indices + 1) % test_buffer.capacity
-    batch_next_state = {
-        "value": test_buffer.states["value"][next_indices].to(test_buffer.device)
-    }
+    batch_next_state = {"value": test_buffer.states["value"][next_indices].to(test_buffer.device)}

     # Get other tensors
     batch_dones = test_buffer.dones[all_indices].to(test_buffer.device)
@@ -1121,9 +1054,7 @@ if __name__ == "__main__":
     print("- We always use the next state in the buffer (index+1) as next_state")
     print("- For terminal states, this means using the first state of the next episode")
     print("- This is a common tradeoff in RL implementations for memory efficiency")
-    print(
-        "- Since we track done flags, the algorithm can handle these transitions correctly"
-    )
+    print("- Since we track done flags, the algorithm can handle these transitions correctly")

     # Test random sampling
     print("\nVerifying random sampling with simplified memory optimization...")
@@ -1137,23 +1068,19 @@ if __name__ == "__main__":
     # Print a few samples
     print("Random samples - State, Next State, Done (First 10):")
     for i in range(10):
-        print(
-            f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}"
-        )
+        print(f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}")

     # Calculate memory savings
     # Assume optimized_buffer and standard_buffer have already been initialized and filled
     std_mem = (
         sum(
-            standard_buffer.states[key].nelement()
-            * standard_buffer.states[key].element_size()
+            standard_buffer.states[key].nelement() * standard_buffer.states[key].element_size()
             for key in standard_buffer.states
         )
         * 2
     )
     opt_mem = sum(
-        optimized_buffer.states[key].nelement()
-        * optimized_buffer.states[key].element_size()
+        optimized_buffer.states[key].nelement() * optimized_buffer.states[key].element_size()
         for key in optimized_buffer.states
     )
@@ -225,9 +225,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Crop rectangular ROIs from a LeRobot dataset."
-    )
+    parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
     parser.add_argument(
         "--repo-id",
         type=str,
@@ -249,9 +247,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     local_files_only = args.root is not None
-    dataset = LeRobotDataset(
-        repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
-    )
+    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)

     images = get_image_from_lerobot_dataset(dataset)
     images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
@@ -1,13 +1,14 @@
-from lerobot.common.robot_devices.robots.factory import make_robot
-from lerobot.common.utils.utils import init_hydra_config
-from lerobot.common.robot_devices.utils import busy_wait
-from lerobot.scripts.server.kinematics import RobotKinematics
 import argparse
 import logging
 import time
-import torch
-import numpy as np
-import argparse

+import numpy as np
+import torch
+
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.robot_devices.utils import busy_wait
+from lerobot.common.utils.utils import init_hydra_config
+from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)
@@ -187,9 +188,7 @@ class KeyboardController(InputController):
 class GamepadController(InputController):
     """Generate motion deltas from gamepad input."""

-    def __init__(
-        self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1
-    ):
+    def __init__(self, x_step_size=0.01, y_step_size=0.01, z_step_size=0.01, deadzone=0.1):
         super().__init__(x_step_size, y_step_size, z_step_size)
         self.deadzone = deadzone
         self.joystick = None
@@ -203,9 +202,7 @@ class GamepadController(InputController):
         pygame.joystick.init()

         if pygame.joystick.get_count() == 0:
-            logging.error(
-                "No gamepad detected. Please connect a gamepad and try again."
-            )
+            logging.error("No gamepad detected. Please connect a gamepad and try again.")
             self.running = False
             return
@@ -338,18 +335,12 @@ class GamepadControllerHID(InputController):

         devices = hid.enumerate()
         for device in devices:
-            if (
-                device["vendor_id"] == self.vendor_id
-                and device["product_id"] == self.product_id
-            ):
-                logging.info(
-                    f"Found gamepad: {device.get('product_string', 'Unknown')}"
-                )
+            if device["vendor_id"] == self.vendor_id and device["product_id"] == self.product_id:
+                logging.info(f"Found gamepad: {device.get('product_string', 'Unknown')}")
                 return device

         logging.error(
-            f"No gamepad with vendor ID 0x{self.vendor_id:04X} and "
-            f"product ID 0x{self.product_id:04X} found"
+            f"No gamepad with vendor ID 0x{self.vendor_id:04X} and product ID 0x{self.product_id:04X} found"
         )
         return None
@@ -381,9 +372,7 @@ class GamepadControllerHID(InputController):

         except OSError as e:
             logging.error(f"Error opening gamepad: {e}")
-            logging.error(
-                "You might need to run this with sudo/admin privileges on some systems"
-            )
+            logging.error("You might need to run this with sudo/admin privileges on some systems")
             self.running = False

     def stop(self):
@@ -421,12 +410,8 @@ class GamepadControllerHID(InputController):
             # Apply deadzone
             self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
             self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
-            self.right_x = (
-                0 if abs(self.right_x) < self.deadzone else self.right_x
-            )
-            self.right_y = (
-                0 if abs(self.right_y) < self.deadzone else self.right_y
-            )
+            self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
+            self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y

             # Parse button states (byte 5 in the Logitech RumblePad 2)
             buttons = data[5]
@@ -493,9 +478,7 @@ def test_inverse_kinematics(robot, fps=10):
         joint_positions = obs["observation.state"].cpu().numpy()
         ee_pos = RobotKinematics.fk_gripper_tip(joint_positions)
         desired_ee_pos = ee_pos
-        target_joint_state = RobotKinematics.ik(
-            joint_positions, desired_ee_pos, position_only=True
-        )
+        target_joint_state = RobotKinematics.ik(joint_positions, desired_ee_pos, position_only=True)
         robot.send_action(torch.from_numpy(target_joint_state))
         logging.info(f"Target Joint State: {target_joint_state}")
         busy_wait(1 / fps - (time.perf_counter() - loop_start_time))
@@ -573,17 +556,13 @@ def teleoperate_delta_inverse_kinematics_with_leader(robot, fps=10):
         robot.send_action(torch.from_numpy(target_joint_state))

         # Logging
-        logging.info(
-            f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}"
-        )
+        logging.info(f"Current EE: {current_ee_pos[:3, 3]}, Desired EE: {desired_ee_pos[:3, 3]}")
         logging.info(f"Delta EE: {ee_delta[:3, 3]}")

         busy_wait(1 / fps - (time.perf_counter() - loop_start_time))


-def teleoperate_delta_inverse_kinematics(
-    robot, controller, fps=10, bounds=None, fk_func=None
-):
+def teleoperate_delta_inverse_kinematics(robot, controller, fps=10, bounds=None, fk_func=None):
     """
     Control a robot using delta end-effector movements from any input controller.
@@ -597,9 +576,7 @@ def teleoperate_delta_inverse_kinematics(
     if fk_func is None:
         fk_func = RobotKinematics.fk_gripper_tip

-    logging.info(
-        f"Testing Delta End-Effector Control with {controller.__class__.__name__}"
-    )
+    logging.info(f"Testing Delta End-Effector Control with {controller.__class__.__name__}")

     # Initial position capture
     obs = robot.capture_observation()
@@ -631,9 +608,7 @@ def teleoperate_delta_inverse_kinematics(

             # Apply bounds if provided
             if bounds is not None:
-                desired_ee_pos[:3, 3] = np.clip(
-                    desired_ee_pos[:3, 3], bounds["min"], bounds["max"]
-                )
+                desired_ee_pos[:3, 3] = np.clip(desired_ee_pos[:3, 3], bounds["min"], bounds["max"])

             # Only send commands if there's actual movement
             if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
@@ -684,14 +659,10 @@ def teleoperate_gym_env(env, controller, fps: int = 30):
             if any([abs(v) > 0.001 for v in [delta_x, delta_y, delta_z]]):
                 # Step the environment - pass action as a tensor with intervention flag
                 action_tensor = torch.from_numpy(action.astype(np.float32))
-                obs, reward, terminated, truncated, info = env.step(
-                    (action_tensor, False)
-                )
+                obs, reward, terminated, truncated, info = env.step((action_tensor, False))

                 # Log information
-                logging.info(
-                    f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]"
-                )
+                logging.info(f"Action: [{delta_x:.4f}, {delta_y:.4f}, {delta_z:.4f}]")
                 logging.info(f"Reward: {reward}")

                 # Reset if episode ended
@@ -761,20 +732,14 @@ if __name__ == "__main__":
     # Determine controller type based on mode prefix
     controller = None
     if args.mode.startswith("keyboard"):
-        controller = KeyboardController(
-            x_step_size=0.01, y_step_size=0.01, z_step_size=0.05
-        )
+        controller = KeyboardController(x_step_size=0.01, y_step_size=0.01, z_step_size=0.05)
     elif args.mode.startswith("gamepad"):
-        controller = GamepadController(
-            x_step_size=0.02, y_step_size=0.02, z_step_size=0.05
-        )
+        controller = GamepadController(x_step_size=0.02, y_step_size=0.02, z_step_size=0.05)

     # Handle mode categories
     if args.mode in ["keyboard", "gamepad"]:
         # Direct robot control modes
-        teleoperate_delta_inverse_kinematics(
-            robot, controller, bounds=bounds, fps=10
-        )
+        teleoperate_delta_inverse_kinematics(robot, controller, bounds=bounds, fps=10)

     elif args.mode in ["keyboard_gym", "gamepad_gym"]:
         # Gym environment control modes
@@ -32,9 +32,7 @@ def find_joint_bounds(
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
-                cv2.imshow(
-                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
-                )
+                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
             cv2.waitKey(1)

         if time.perf_counter() - start_episode_t > control_time_s:
@@ -69,9 +67,7 @@ def find_ee_bounds(
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
-                cv2.imshow(
-                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
-                )
+                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
             cv2.waitKey(1)

         if time.perf_counter() - start_episode_t > control_time_s:
@@ -1,10 +1,10 @@
 import argparse
-import sys
-
 import logging
+import sys
 import time
 from threading import Lock
 from typing import Annotated, Any, Dict, Tuple

 import gymnasium as gym
 import numpy as np
 import torch
@@ -18,7 +18,6 @@ from lerobot.common.robot_devices.control_utils import (
 )
 from lerobot.common.robot_devices.robots.factory import make_robot
 from lerobot.common.utils.utils import init_hydra_config, log_say
-
 from lerobot.scripts.server.kinematics import RobotKinematics

 logging.basicConfig(level=logging.INFO)
@@ -67,9 +66,7 @@ class HILSerlRobotEnv(gym.Env):
         if not self.robot.is_connected:
             self.robot.connect()

-        self.initial_follower_position = robot.follower_arms["main"].read(
-            "Present_Position"
-        )
+        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

         # Episode tracking.
         self.current_step = 0
@@ -77,9 +74,7 @@ class HILSerlRobotEnv(gym.Env):

         self.delta = delta
         self.use_delta_action_space = use_delta_action_space
-        self.current_joint_positions = self.robot.follower_arms["main"].read(
-            "Present_Position"
-        )
+        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

         # Retrieve the size of the joint position interval bound.
         self.relative_bounds_size = (
@@ -92,9 +87,7 @@ class HILSerlRobotEnv(gym.Env):
         )

         self.robot.config.max_relative_target = (
-            self.relative_bounds_size.float()
-            if self.relative_bounds_size is not None
-            else None
+            self.relative_bounds_size.float() if self.relative_bounds_size is not None else None
         )

         # Dynamically configure the observation and action spaces.
@@ -119,9 +112,7 @@ class HILSerlRobotEnv(gym.Env):
         # Define observation spaces for images and other states.
         image_keys = [key for key in example_obs if "image" in key]
         observation_spaces = {
-            key: gym.spaces.Box(
-                low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8
-            )
+            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
             for key in image_keys
         }
         observation_spaces["observation.state"] = gym.spaces.Box(
@@ -172,9 +163,7 @@ class HILSerlRobotEnv(gym.Env):
             ),
         )

-    def reset(
-        self, seed=None, options=None
-    ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
+    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
         """
         Reset the environment to its initial state.
         This method resets the step counter and clears any episodic data.
@@ -231,35 +220,25 @@ class HILSerlRobotEnv(gym.Env):
|
||||
"""
|
||||
policy_action, intervention_bool = action
|
||||
teleop_action = None
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read(
|
||||
"Present_Position"
|
||||
)
|
||||
self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
|
||||
if isinstance(policy_action, torch.Tensor):
|
||||
policy_action = policy_action.cpu().numpy()
|
||||
policy_action = np.clip(
|
||||
policy_action, self.action_space[0].low, self.action_space[0].high
|
||||
)
|
||||
policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
|
||||
|
||||
if not intervention_bool:
|
||||
if self.use_delta_action_space:
|
||||
target_joint_positions = (
|
||||
self.current_joint_positions + self.delta * policy_action
|
||||
)
|
||||
target_joint_positions = self.current_joint_positions + self.delta * policy_action
|
||||
else:
|
||||
target_joint_positions = policy_action
|
||||
self.robot.send_action(torch.from_numpy(target_joint_positions))
|
||||
observation = self.robot.capture_observation()
|
||||
else:
|
||||
observation, teleop_action = self.robot.teleop_step(record_data=True)
|
||||
teleop_action = teleop_action[
|
||||
"action"
|
||||
] # Convert tensor to appropriate format
|
||||
teleop_action = teleop_action["action"] # Convert tensor to appropriate format
|
||||
|
||||
# When applying the delta action space, convert teleop absolute values to relative differences.
|
||||
if self.use_delta_action_space:
|
||||
teleop_action = (
|
||||
teleop_action - self.current_joint_positions
|
||||
) / self.delta
|
||||
teleop_action = (teleop_action - self.current_joint_positions) / self.delta
|
||||
if self.relative_bounds_size is not None and (
|
||||
torch.any(teleop_action < -self.relative_bounds_size)
|
||||
and torch.any(teleop_action > self.relative_bounds_size)
|
||||
@@ -333,12 +312,8 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
|
||||
|
||||
self.last_joint_positions = np.zeros(old_shape)
|
||||
|
||||
new_low = np.concatenate(
|
||||
[old_low, np.ones_like(old_low) * -joint_velocity_limits]
|
||||
)
|
||||
new_high = np.concatenate(
|
||||
[old_high, np.ones_like(old_high) * joint_velocity_limits]
|
||||
)
|
||||
new_low = np.concatenate([old_low, np.ones_like(old_low) * -joint_velocity_limits])
|
||||
new_high = np.concatenate([old_high, np.ones_like(old_high) * joint_velocity_limits])
|
||||
|
||||
new_shape = (old_shape[0] * 2,)
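The wrapper above doubles the state vector by appending finite-difference joint velocities with symmetric bounds. A minimal standalone sketch of the same idea, assuming a plain Box state space; the fps and velocity_limit values are illustrative:

import gymnasium as gym
import numpy as np


class VelocityAugmentedObservation(gym.ObservationWrapper):
    """Append finite-difference velocities to a Box observation (sketch)."""

    def __init__(self, env, fps=30.0, velocity_limit=100.0):
        super().__init__(env)
        low, high = env.observation_space.low, env.observation_space.high
        # New bounds: original positions followed by symmetric velocity bounds.
        self.observation_space = gym.spaces.Box(
            low=np.concatenate([low, -velocity_limit * np.ones_like(low)]),
            high=np.concatenate([high, velocity_limit * np.ones_like(high)]),
            dtype=np.float32,
        )
        self.dt = 1.0 / fps
        self.last_obs = np.zeros_like(low)

    def observation(self, obs):
        velocity = (obs - self.last_obs) / self.dt  # backward difference, like the wrapper above
        self.last_obs = obs.copy()
        return np.concatenate([obs, velocity]).astype(np.float32)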

@@ -352,9 +327,7 @@ class AddJointVelocityToObservation(gym.ObservationWrapper):
self.dt = 1.0 / fps

def observation(self, observation):
joint_velocities = (
observation["observation.state"] - self.last_joint_positions
) / self.dt
joint_velocities = (observation["observation.state"] - self.last_joint_positions) / self.dt
self.last_joint_positions = observation["observation.state"].clone()
observation["observation.state"] = torch.cat(
[observation["observation.state"], joint_velocities], dim=-1
@@ -439,9 +412,7 @@ class JointMaskingActionSpace(gym.Wrapper):
raise ValueError("Mask length must match action space dimensions")
low = env.action_space.low[self.active_dims]
high = env.action_space.high[self.active_dims]
self.action_space = gym.spaces.Box(
low=low, high=high, dtype=env.action_space.dtype
)
self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)

if isinstance(env.action_space, gym.spaces.Tuple):
if len(mask) != env.action_space[0].shape[0]:
@@ -449,12 +420,8 @@ class JointMaskingActionSpace(gym.Wrapper):

low = env.action_space[0].low[self.active_dims]
high = env.action_space[0].high[self.active_dims]
action_space_masked = gym.spaces.Box(
low=low, high=high, dtype=env.action_space[0].dtype
)
self.action_space = gym.spaces.Tuple(
(action_space_masked, env.action_space[1])
)
action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))
# Create new action space with masked dimensions

def action(self, action):
@@ -473,18 +440,14 @@ class JointMaskingActionSpace(gym.Wrapper):
# Extract the masked component from the tuple.
masked_action = action[0] if isinstance(action, tuple) else action
# Create a full action for the Box element.
full_box_action = np.zeros(
self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype
)
full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
full_box_action[self.active_dims] = masked_action
# Return a tuple with the reconstructed Box action and the unchanged remainder.
return (full_box_action, action[1])
else:
# For Box action spaces.
masked_action = action if not isinstance(action, tuple) else action[0]
full_action = np.zeros(
self.env.action_space.shape, dtype=self.env.action_space.dtype
)
full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
full_action[self.active_dims] = masked_action
return full_action

@@ -493,13 +456,9 @@ class JointMaskingActionSpace(gym.Wrapper):
obs, reward, terminated, truncated, info = self.env.step(action)
if "action_intervention" in info and info["action_intervention"] is not None:
if info["action_intervention"].dim() == 1:
info["action_intervention"] = info["action_intervention"][
self.active_dims
]
info["action_intervention"] = info["action_intervention"][self.active_dims]
else:
info["action_intervention"] = info["action_intervention"][
:, self.active_dims
]
info["action_intervention"] = info["action_intervention"][:, self.active_dims]
return obs, reward, terminated, truncated, info
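The masking logic above shrinks the policy's action space to the active joints and scatters reduced actions back into the full command. The core index trick, as a standalone sketch with a hypothetical 6-DoF mask:

import numpy as np

mask = [True, True, True, False, False, True]  # illustrative: joints 3 and 4 frozen
active_dims = np.where(mask)[0]

def expand_action(masked_action: np.ndarray, full_dim: int = 6) -> np.ndarray:
    # Scatter the reduced action back into the full joint command; frozen joints stay at zero.
    full_action = np.zeros(full_dim, dtype=np.float32)
    full_action[active_dims] = masked_action
    return full_action

print(expand_action(np.array([0.1, -0.2, 0.05, 0.3], dtype=np.float32)))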


@@ -555,9 +514,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
for key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
new_shape = (top + height, left + width)
self.observation_space[key] = gym.spaces.Box(
low=0, high=255, shape=new_shape
)
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)

self.resize_size = resize_size
if self.resize_size is None:
@@ -583,9 +540,7 @@ class ImageCropResizeWrapper(gym.Wrapper):
)
# Check for NaNs before processing
if torch.isnan(obs[k]).any():
logging.error(
f"NaN values detected in observation {k} before crop and resize"
)
logging.error(f"NaN values detected in observation {k} before crop and resize")

if device == torch.device("mps:0"):
obs[k] = obs[k].cpu()
@@ -595,9 +550,7 @@ class ImageCropResizeWrapper(gym.Wrapper):

# Check for NaNs after processing
if torch.isnan(obs[k]).any():
logging.error(
f"NaN values detected in observation {k} after crop and resize"
)
logging.error(f"NaN values detected in observation {k} after crop and resize")

obs[k] = obs[k].to(device)

@@ -627,14 +580,10 @@ class ConvertToLeRobotObservation(gym.ObservationWrapper):
observation = preprocess_observation(observation)

observation = {
key: observation[key].to(
self.device, non_blocking=self.device.type == "cuda"
)
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
for key in observation
}
observation = {
k: torch.tensor(v, device=self.device) for k, v in observation.items()
}
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
return observation


@@ -686,26 +635,16 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
play_sounds=True,
)
return
if (
self.events["pause_policy"]
and not self.events["human_intervention_step"]
):
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
self.events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say(
"Starting human intervention.", play_sounds=True
)
log_say("Starting human intervention.", play_sounds=True)
return
if (
self.events["pause_policy"]
and self.events["human_intervention_step"]
):
if self.events["pause_policy"] and self.events["human_intervention_step"]:
self.events["pause_policy"] = False
self.events["human_intervention_step"] = False
print("Space key pressed for a third time.")
log_say(
"Continuing with policy actions.", play_sounds=True
)
log_say("Continuing with policy actions.", play_sounds=True)
return
except Exception as e:
print(f"Error handling key press: {e}")
@@ -713,9 +652,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
self.listener = keyboard.Listener(on_press=on_press)
self.listener.start()
except ImportError:
logging.warning(
"Could not import pynput. Keyboard interface will not be available."
)
logging.warning("Could not import pynput. Keyboard interface will not be available.")
self.listener = None

def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
@@ -742,9 +679,7 @@ class KeyboardInterfaceWrapper(gym.Wrapper):
time.sleep(0.1) # Check more frequently if desired

# Execute the step in the underlying environment
obs, reward, terminated, truncated, info = self.env.step(
(policy_action, is_intervention)
)
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))

# Override reward and termination if episode success event triggered
with self.event_lock:
@@ -807,9 +742,7 @@ class BatchCompitableWrapper(gym.ObservationWrapper):
def __init__(self, env):
super().__init__(env)

def observation(
self, observation: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
for key in observation:
if "image" in key and observation[key].dim() == 3:
observation[key] = observation[key].unsqueeze(0)
@@ -844,9 +777,7 @@ class EEActionWrapper(gym.ActionWrapper):
dtype=np.float32,
)
if isinstance(self.action_space, gym.spaces.Tuple):
self.action_space = gym.spaces.Tuple(
(ee_action_space, self.action_space[1])
)
self.action_space = gym.spaces.Tuple((ee_action_space, self.action_space[1]))
else:
self.action_space = ee_action_space

@@ -858,9 +789,7 @@ class EEActionWrapper(gym.ActionWrapper):
if isinstance(action, tuple):
action, _ = action

current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
current_ee_pos = self.fk_function(current_joint_pos)
if isinstance(action, torch.Tensor):
action = action.cpu().numpy()
@@ -898,9 +827,7 @@ class EEObservationWrapper(gym.ObservationWrapper):
self.fk_function = self.kinematics.fk_gripper_tip

def observation(self, observation):
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read(
"Present_Position"
)
current_joint_pos = self.unwrapped.robot.follower_arms["main"].read("Present_Position")
current_ee_pos = self.fk_function(current_joint_pos)
observation["observation.state"] = torch.cat(
[
@@ -944,8 +871,8 @@ class GamepadControlWrapper(gym.Wrapper):
"""
super().__init__(env)
from lerobot.scripts.server.end_effector_control_utils import (
GamepadControllerHID,
GamepadController,
GamepadControllerHID,
)

# use HidApi for macos
@@ -1027,9 +954,7 @@ class GamepadControlWrapper(gym.Wrapper):

# Update episode ending state if requested
if terminate_episode:
logging.info(
f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}"
)
logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}")

# Only override the action if gamepad is active
if is_intervention:
@@ -1054,9 +979,7 @@ class GamepadControlWrapper(gym.Wrapper):
logging.info("Episode ended successfully with reward 1.0")

info["is_intervention"] = is_intervention
action_intervention = (
final_action[0] if isinstance(final_action, Tuple) else final_action
)
action_intervention = final_action[0] if isinstance(final_action, Tuple) else final_action
if isinstance(action_intervention, np.ndarray):
action_intervention = torch.from_numpy(action_intervention)
info["action_intervention"] = action_intervention
@@ -1087,9 +1010,7 @@ class GamepadControlWrapper(gym.Wrapper):
class ActionScaleWrapper(gym.ActionWrapper):
def __init__(self, env, ee_action_space_params=None):
super().__init__(env)
assert ee_action_space_params is not None, (
"TODO: method implemented for ee action space only so far"
)
assert ee_action_space_params is not None, "TODO: method implemented for ee action space only so far"
self.scale_vector = np.array(
[
[
@@ -1148,9 +1069,7 @@ def make_robot_env(
if cfg.env.wrapper.add_joint_velocity_to_observation:
env = AddJointVelocityToObservation(env=env, fps=cfg.fps)
if cfg.env.wrapper.add_ee_pose_to_observation:
env = EEObservationWrapper(
env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds
)
env = EEObservationWrapper(env=env, ee_pose_limits=cfg.env.wrapper.ee_action_space_params.bounds)

env = ConvertToLeRobotObservation(env=env, device=cfg.env.device)

@@ -1163,13 +1082,9 @@ def make_robot_env(

# Add reward computation and control wrappers
# env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
env = TimeLimitWrapper(
env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
)
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
if cfg.env.wrapper.ee_action_space_params is not None:
env = EEActionWrapper(
env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params
)
env = EEActionWrapper(env=env, ee_action_space_params=cfg.env.wrapper.ee_action_space_params)
if (
cfg.env.wrapper.ee_action_space_params is not None
and cfg.env.wrapper.ee_action_space_params.use_gamepad
@@ -1193,9 +1108,7 @@ def make_robot_env(
cfg.env.wrapper.ee_action_space_params is None
and cfg.env.wrapper.joint_masking_action_space is not None
):
env = JointMaskingActionSpace(
env=env, mask=cfg.env.wrapper.joint_masking_action_space
)
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
env = BatchCompitableWrapper(env=env)

return env
@@ -1216,9 +1129,7 @@ def get_classifier(pretrained_path, config_path, device="mps"):
cfg = init_hydra_config(config_path)

classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
classifier_config.num_cameras = len(
cfg.training.image_keys
) # TODO automate these paths
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
model = Classifier(classifier_config)
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
model = model.to(device)
@@ -1317,9 +1228,7 @@ def record_dataset(

# For teleop, get action from intervention
if policy is None:
action = {
"action": info["action_intervention"].cpu().squeeze(0).float()
}
action = {"action": info["action_intervention"].cpu().squeeze(0).float()}

# Process observation for dataset
obs = {k: v.cpu().squeeze(0).float() for k, v in obs.items()}
@@ -1357,9 +1266,7 @@ def replay_episode(env, repo_id, root=None, episode=0):
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

local_files_only = root is not None
dataset = LeRobotDataset(
repo_id, root=root, episodes=[episode], local_files_only=local_files_only
)
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
env.reset()

actions = dataset.hf_dataset.select_columns("action")
@@ -1414,9 +1321,7 @@ if __name__ == "__main__":
default=None,
help="Path to a yaml config file that is necessary to build the reward classifier model.",
)
parser.add_argument(
"--env-path", type=str, default=None, help="Path to the env yaml file"
)
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
parser.add_argument(
"--env-overrides",
type=str,
@@ -1441,12 +1346,8 @@ if __name__ == "__main__":
default=None,
help="Repo ID of the episode to replay",
)
parser.add_argument(
"--dataset-root", type=str, default=None, help="Root of the dataset to replay"
)
parser.add_argument(
"--replay-episode", type=int, default=0, help="Episode to replay"
)
parser.add_argument("--dataset-root", type=str, default=None, help="Root of the dataset to replay")
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
parser.add_argument(
"--record-repo-id",
type=str,
@@ -1534,9 +1435,7 @@ if __name__ == "__main__":
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action

# Execute the step: wrap the NumPy action in a torch tensor.
obs, reward, terminated, truncated, info = env.step(
(torch.from_numpy(smoothed_action), False)
)
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
if terminated or truncated:
sucesses.append(reward)
env.reset()
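The test loop above smooths uniform random actions with an exponential moving average before stepping the environment, which keeps the arm from jerking between independent samples. The rule itself, in an isolated sketch with an illustrative alpha:

import numpy as np

rng = np.random.default_rng(0)
alpha = 0.2  # illustrative smoothing factor: higher alpha tracks new samples faster
smoothed = np.zeros(6, dtype=np.float32)
for _ in range(5):
    new_sample = rng.uniform(-1.0, 1.0, size=6).astype(np.float32)
    smoothed = alpha * new_sample + (1 - alpha) * smoothed  # same EMA update as the loop above
    print(smoothed)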

@@ -23,11 +23,7 @@ def screw_axis_to_transform(S, theta):
elif np.linalg.norm(S_w) == 1: # Rotation and translation
w_hat = skew_symmetric(S_w)
R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
t = (
np.eye(3) * theta
+ (1 - np.cos(theta)) * w_hat
+ (theta - np.sin(theta)) * w_hat @ w_hat
) @ S_v
t = (np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat) @ S_v
T = np.eye(4)
T[:3, :3] = R
T[:3, 3] = t
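The reformatted hunk above is the closed-form matrix exponential of a screw axis S = (S_w, S_v): Rodrigues' formula gives the rotation block, and G(theta) = I*theta + (1 - cos theta)[w] + (theta - sin theta)[w]^2 applied to S_v gives the translation. A self-contained numpy sketch of the textbook construction (not the repository's exact skew_symmetric helper):

import numpy as np

def skew(w):
    # Skew-symmetric matrix [w] such that [w] @ v == np.cross(w, v).
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])

def screw_to_transform(S, theta):
    S_w, S_v = S[:3], S[3:]
    w_hat = skew(S_w)
    # Rodrigues' rotation formula.
    R = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat
    # Integral of the rotation applied to the linear velocity part.
    G = np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = G @ S_v
    return T

# Pure rotation of pi/2 about z through the origin:
print(np.round(screw_to_transform(np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0]), np.pi / 2), 3))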
@@ -189,9 +185,7 @@ class RobotKinematics:

# Wrist
# Screw axis of wrist frame wrt base frame
self.S_BR = np.array(
[0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]]
)
self.S_BR = np.array([0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]])

# 0-position origin to centroid transform
self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002)
@@ -284,12 +278,7 @@ class RobotKinematics:
def fk_shoulder(self, robot_pos_deg):
"""Forward kinematics for the shoulder frame."""
robot_pos_rad = robot_pos_deg / 180 * np.pi
return (
self.X_WoBo
@ screw_axis_to_transform(self.S_BS, robot_pos_rad[0])
@ self.X_SoSc
@ self.X_BS
)
return self.X_WoBo @ screw_axis_to_transform(self.S_BS, robot_pos_rad[0]) @ self.X_SoSc @ self.X_BS

def fk_humerus(self, robot_pos_deg):
"""Forward kinematics for the humerus frame."""
@@ -403,15 +392,12 @@ class RobotKinematics:
delta *= 0
delta[el_ix] = eps / 2
Sdot = (
fk_func(robot_pos_deg[:-1] + delta)[:3, 3]
- fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
fk_func(robot_pos_deg[:-1] + delta)[:3, 3] - fk_func(robot_pos_deg[:-1] - delta)[:3, 3]
) / eps
jac[:, el_ix] = Sdot
return jac

def ik(
self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None
):
def ik(self, current_joint_state, desired_ee_pose, position_only=True, fk_func=None):
"""Inverse kinematics using gradient descent.

Args:
@@ -457,9 +443,7 @@ if __name__ == "__main__":

# Test 1: Forward kinematics consistency
print("Test 1: Forward kinematics consistency")
test_angles = np.array(
[30, 45, -30, 20, 10, 0]
) # Example joint angles in degrees
test_angles = np.array([30, 45, -30, 20, 10, 0]) # Example joint angles in degrees

# Calculate FK for different joints
shoulder_pose = robot.fk_shoulder(test_angles)
@@ -480,13 +464,9 @@ if __name__ == "__main__":
]

# Check if distances generally increase along the chain
is_consistent = all(
distances[i] <= distances[i + 1] for i in range(len(distances) - 1)
)
is_consistent = all(distances[i] <= distances[i + 1] for i in range(len(distances) - 1))
print(f" Pose distances from origin: {[round(d, 3) for d in distances]}")
print(
f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}"
)
print(f" Kinematic chain consistency: {'PASSED' if is_consistent else 'FAILED'}")

# Test 2: Jacobian computation
print("Test 2: Jacobian computation")
@@ -498,9 +478,7 @@ if __name__ == "__main__":
pos_jacobian_shape_ok = positional_jacobian.shape == (3, 5)

print(f" Jacobian shape: {'PASSED' if jacobian_shape_ok else 'FAILED'}")
print(
f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}"
)
print(f" Positional Jacobian shape: {'PASSED' if pos_jacobian_shape_ok else 'FAILED'}")

# Test 3: Inverse kinematics
print("Test 3: Inverse kinematics (position only)")

@@ -17,15 +17,8 @@
import logging
import shutil
import time
from pprint import pformat
from concurrent.futures import ThreadPoolExecutor

# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue

from lerobot.scripts.server.utils import setup_process_handlers
from pprint import pformat

import grpc

@@ -37,6 +30,11 @@ from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import nn

# from torch.multiprocessing import Event, Queue, Process
# from threading import Event, Thread
# from torch.multiprocessing import Queue, Event
from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer

from lerobot.common.datasets.factory import make_dataset
@@ -55,18 +53,17 @@ from lerobot.common.utils.utils import (
set_global_random_state,
set_global_seed,
)

from lerobot.scripts.server import learner_service
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_transition_to_device,
move_state_dict_to_device,
bytes_to_transitions,
state_to_bytes,
bytes_to_python_object,
bytes_to_transitions,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
state_to_bytes,
)

from lerobot.scripts.server import learner_service
from lerobot.scripts.server.utils import setup_process_handlers


def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
@@ -81,13 +78,9 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")

checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"Resume=True detected, resuming previous run",
@@ -136,9 +129,7 @@ def load_training_state(


def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

log_output_dir(out_dir)
@@ -210,22 +201,15 @@ def initialize_offline_replay_buffer(
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if (
policy.config.vision_encoder_name is None
or not policy.config.freeze_vision_encoder
):
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None

with torch.no_grad():
observation_features = (
policy.actor.encoder(observations)
if policy.actor.encoder is not None
else None
policy.actor.encoder(observations) if policy.actor.encoder is not None else None
)
next_observation_features = (
policy.actor.encoder(next_observations)
if policy.actor.encoder is not None
else None
policy.actor.encoder(next_observations) if policy.actor.encoder is not None else None
)

return observation_features, next_observation_features
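get_observation_features above caches encoder outputs only when a vision encoder exists and is frozen: with frozen weights the features cannot change between gradient steps, so recomputing them inside every critic and actor update would be wasted work. The pattern reduces to this sketch, where encoder and frozen stand in for the policy internals:

import torch
from torch import nn

def maybe_precompute_features(encoder: nn.Module | None, frozen: bool, obs: torch.Tensor):
    # Cache features only when the encoder exists and its weights cannot change;
    # otherwise they must be recomputed with gradients on each update.
    if encoder is None or not frozen:
        return None
    with torch.no_grad():
        return encoder(obs)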
@@ -452,9 +436,7 @@ def add_actor_information_and_train(
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online traning, we do not need dataset_stats
dataset_stats=None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
if cfg.resume
else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)

# Update the policy config with the grad_clip_norm value from training config if it exists
@@ -469,9 +451,7 @@ def add_actor_information_and_train(
last_time_policy_pushed = time.time()

optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
resume_optimization_step, resume_interaction_step = load_training_state(
cfg, logger, optimizers
)
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

log_training_info(cfg, out_dir, policy)

@@ -483,9 +463,7 @@ def add_actor_information_and_train(
active_action_dims = None
if cfg.env.wrapper.joint_masking_action_space is not None:
active_action_dims = [
i
for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
if mask
i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask
]
offline_replay_buffer = initialize_offline_replay_buffer(
cfg=cfg,
@@ -502,12 +480,8 @@ def add_actor_information_and_train(
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
optimization_step = (
resume_optimization_step if resume_optimization_step is not None else 0
)
interaction_step_shift = (
resume_interaction_step if resume_interaction_step is not None else 0
)
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0

# Extract variables from cfg
online_step_before_learning = cfg.training.online_step_before_learning
@@ -519,9 +493,7 @@ def add_actor_information_and_train(
device = cfg.device
storage_device = cfg.training.storage_device
policy_update_freq = cfg.training.policy_update_freq
policy_parameters_push_frequency = (
cfg.actor_learner_config.policy_parameters_push_frequency
)
policy_parameters_push_frequency = cfg.actor_learner_config.policy_parameters_push_frequency
save_checkpoint = cfg.training.save_checkpoint
online_steps = cfg.training.online_steps

@@ -544,9 +516,9 @@ def add_actor_information_and_train(
continue
replay_buffer.add(**transition)

if cfg.dataset_repo_id is not None and transition.get(
"complementary_info", {}
).get("is_intervention"):
if cfg.dataset_repo_id is not None and transition.get("complementary_info", {}).get(
"is_intervention"
):
offline_replay_buffer.add(**transition)

logging.debug("[LEARNER] Received transitions")
@@ -556,9 +528,7 @@ def add_actor_information_and_train(
interaction_message = bytes_to_python_object(interaction_message)
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
logger.log_dict(
interaction_message, mode="train", custom_step_key="Interaction step"
)
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")

logging.debug("[LEARNER] Received interactions")

@@ -579,9 +549,7 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
@@ -619,9 +587,7 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]

check_nan_in_transition(
observations=observations, actions=actions, next_state=next_observations
)
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)

observation_features, next_observation_features = get_observation_features(
policy, observations, next_observations
@@ -697,23 +663,15 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
training_infos["offline_replay_buffer_size"] = len(
offline_replay_buffer
)
training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer)
training_infos["Optimization step"] = optimization_step
logger.log_dict(
d=training_infos, mode="train", custom_step_key="Optimization step"
)
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
# logging.info(f"Training infos: {training_infos}")

time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (
time_for_one_optimization_step + 1e-9
)
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

logging.info(
f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
)
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

logger.log_dict(
{
@@ -728,16 +686,12 @@ def add_actor_information_and_train(
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")

if save_checkpoint and (
optimization_step % save_freq == 0 or optimization_step == online_steps
):
if save_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps):
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"]
if interaction_message is not None
else 0
interaction_message["Interaction step"] if interaction_message is not None else 0
)
logger.save_checkpoint(
optimization_step,
@@ -755,9 +709,7 @@ def add_actor_information_and_train(
shutil.rmtree(
dataset_dir,
)
replay_buffer.to_lerobot_dataset(
dataset_repo_id, fps=fps, root=logger.log_dir / "dataset"
)
replay_buffer.to_lerobot_dataset(dataset_repo_id, fps=fps, root=logger.log_dir / "dataset")
if offline_replay_buffer is not None:
dataset_dir = logger.log_dir / "dataset_offline"

@@ -809,9 +761,7 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,

@@ -1,10 +1,10 @@
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import logging
from multiprocessing import Event, Queue

from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
from lerobot.scripts.server.network_utils import send_bytes_in_chunks
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore

from lerobot.scripts.server.network_utils import receive_bytes_in_chunks, send_bytes_in_chunks

MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
@@ -64,9 +64,7 @@ class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer):

def SendInteractions(self, request_iterator, _context):
# TODO: authorize the request
logging.info(
"[LEARNER] Received request to receive interactions from the Actor"
)
logging.info("[LEARNER] Received request to receive interactions from the Actor")

receive_bytes_in_chunks(
request_iterator,

@@ -1,12 +1,12 @@
import einops
import numpy as np
import gymnasium as gym
import torch

from omegaconf import DictConfig
from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv

import einops
import gymnasium as gym
import numpy as np
import torch
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
from omegaconf import DictConfig


def preprocess_maniskill_observation(
@@ -63,9 +63,7 @@ class ManiSkillCompat(gym.Wrapper):
new_action_space_shape = env.action_space.shape[-1]
new_low = np.squeeze(env.action_space.low, axis=0)
new_high = np.squeeze(env.action_space.high, axis=0)
self.action_space = gym.spaces.Box(
low=new_low, high=new_high, shape=(new_action_space_shape,)
)
self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))

def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
@@ -84,9 +82,7 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
def __init__(self, env):
super().__init__(env)
self.action_space = gym.spaces.Tuple(
spaces=(env.action_space, gym.spaces.Discrete(2))
)
self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))

def action(self, action):
action, telop = action
@@ -100,9 +96,7 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
action_space_agent: gym.spaces.Box = env.action_space[0]
action_space_agent.low = action_space_agent.low * multiply_factor
action_space_agent.high = action_space_agent.high * multiply_factor
self.action_space = gym.spaces.Tuple(
spaces=(action_space_agent, gym.spaces.Discrete(2))
)
self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))

def step(self, action):
if isinstance(action, tuple):
@@ -153,9 +147,7 @@ def make_maniskill(
)
env = ManiSkillObservationWrapper(env, device=cfg.env.device)
env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
env._max_episode_steps = env.max_episode_steps = (
50 # gym_utils.find_max_episode_steps_value(env)
)
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
env.unwrapped.metadata["render_fps"] = 20
env = ManiSkillCompat(env)
env = ManiSkillActionWrapper(env)
@@ -166,12 +158,11 @@ def make_maniskill(

if __name__ == "__main__":
import argparse

import hydra

parser = argparse.ArgumentParser()
parser.add_argument(
"--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
)
parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
args = parser.parse_args()

# Initialize config

@@ -15,12 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.scripts.server import hilserl_pb2
import logging
import io
from multiprocessing import Queue, Event
import logging
from multiprocessing import Event, Queue
from typing import Any

from lerobot.scripts.server import hilserl_pb2

CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB


@@ -31,9 +32,7 @@ def bytes_buffer_size(buffer: io.BytesIO) -> int:
return result


def send_bytes_in_chunks(
buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
):
def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True):
buffer = io.BytesIO(buffer)
size_in_bytes = bytes_buffer_size(buffer)

@@ -56,16 +55,12 @@ def send_bytes_in_chunks(

yield message_class(transfer_state=transfer_state, data=chunk)
sent_bytes += size_to_read
logging_method(
f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
)
logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}")

logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB")


def receive_bytes_in_chunks(
iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
):
def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""):
bytes_buffer = io.BytesIO()
step = 0

@@ -89,9 +84,7 @@ def receive_bytes_in_chunks(
logging.debug(f"{log_prefix} Received data at step {step}")
elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
bytes_buffer.write(item.data)
logging.debug(
f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}"
)
logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}")

queue.put(bytes_buffer.getvalue())
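send_bytes_in_chunks and receive_bytes_in_chunks above stream a byte buffer over gRPC in fixed-size pieces tagged with a transfer state (begin/middle/end) so the receiver knows when to flush its buffer onto the queue. The chunking reduces to a generator like this sketch (the string states are illustrative stand-ins for the hilserl_pb2 TransferState enum):

import io

CHUNK = 2 * 1024 * 1024  # 2 MB, mirroring CHUNK_SIZE above

def iter_chunks(payload: bytes):
    buffer = io.BytesIO(payload)
    total, sent = len(payload), 0
    while sent < total:
        data = buffer.read(CHUNK)
        state = "END" if sent + len(data) >= total else ("BEGIN" if sent == 0 else "MIDDLE")
        yield state, data
        sent += len(data)

for state, data in iter_chunks(b"x" * (5 * 1024 * 1024)):
    print(state, len(data))  # BEGIN 2MB, MIDDLE 2MB, END 1MB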


@@ -18,9 +18,10 @@
import logging
import signal
import sys
from torch.multiprocessing import Queue
from queue import Empty

from torch.multiprocessing import Queue

shutdown_event_counter = 0


@@ -223,18 +223,12 @@ def train(cfg: TrainPipelineConfig):
step = 0 # number of policy updates (forward + backward + optim)

if cfg.resume:
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler
)
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

num_learnable_params = sum(
p.numel() for p in policy.parameters() if p.requires_grad
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())

logging.info(
colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}"
)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
@@ -335,9 +329,7 @@ def train(cfg: TrainPipelineConfig):
logging.info(f"Eval policy at step {step}")
with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if cfg.policy.use_amp
else nullcontext(),
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
):
eval_info = eval_policy(
eval_env,

@@ -52,19 +52,13 @@ def get_model(cfg, logger): # noqa I001
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(
Classifier.from_pretrained(
str(logger.last_pretrained_model_dir)
).state_dict()
)
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
return model


def create_balanced_sampler(dataset, cfg):
# Get underlying dataset if using Subset
original_dataset = (
dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset
)
original_dataset = dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset

# Get indices if using Subset (for slicing)
indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None
@@ -83,9 +77,7 @@ def create_balanced_sampler(dataset, cfg):
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]

return WeightedRandomSampler(
weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
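create_balanced_sampler above weights every sample by the inverse frequency of its class, so minority-class frames are drawn about as often as majority-class ones. The weight computation in isolation, with dummy labels (assumes classes are indexed 0..K-1):

import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 1])  # dummy imbalanced labels
_, counts = torch.unique(labels, return_counts=True)
class_weights = 1.0 / counts.float()  # rare classes get proportionally larger weight
sample_weights = class_weights[labels]  # per-sample weight looked up by class index
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)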


def support_amp(device: torch.device, cfg: DictConfig) -> bool:
@@ -94,9 +86,7 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
return cfg.training.use_amp and device.type in ("cuda", "cpu")


def train_epoch(
model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
):
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
@@ -110,11 +100,7 @@ def train_epoch(
labels = batch[cfg.training.label_key].float().to(device)

# Forward pass with optional AMP
with (
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext()
):
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)

@@ -159,9 +145,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):

with (
torch.no_grad(),
torch.autocast(device_type=device.type)
if support_amp(device, cfg)
else nullcontext(),
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
):
for batch in tqdm(val_loader, desc="Validation"):
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
@@ -174,9 +158,7 @@ def validate(model, val_loader, criterion, device, logger, cfg):
):
outputs = model(images)
inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
else:
outputs = model(images)
@@ -194,24 +176,16 @@ def validate(model, val_loader, criterion, device, logger, cfg):

# Log sample predictions for visualization
if len(samples) < cfg.eval.num_samples_to_log:
for i in range(
min(cfg.eval.num_samples_to_log - len(samples), len(images))
):
for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [
round(prob, 3) for prob in outputs.probabilities[i].tolist()
]
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
**{
f"image_{img_key}": wandb.Image(
images[img_idx][i].cpu()
)
for img_idx, img_key in enumerate(
cfg.training.image_keys
)
f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
for img_idx, img_key in enumerate(cfg.training.image_keys)
},
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
@@ -286,9 +260,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
_ = model(x)

inference_times.append(
next(
x for x in prof.key_averages() if x.key == "model_inference"
).cpu_time
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)

inference_times = np.array(inference_times)
@@ -314,9 +286,7 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
return avg, median, std


def train(
cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
) -> None:
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None) -> None:
if out_dir is None:
raise NotImplementedError()
if job_name is None:
@@ -372,9 +342,7 @@ def train(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(
Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
@@ -387,9 +355,7 @@ def train(
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(
OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
@@ -408,11 +374,7 @@ def train(

optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = (
nn.BCEWithLogitsLoss()
if model.config.num_classes == 2
else nn.CrossEntropyLoss()
)
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
grad_scaler = GradScaler(enabled=cfg.training.use_amp)

# Log model parameters

@@ -52,9 +52,7 @@ def make_optimizers_and_scheduler(cfg, policy):
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha], lr=policy.config.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -106,9 +104,7 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)

# Gather pixels
cropped_hwcn = images_hwcn[
torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
]
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
# cropped_hwcn => (B, crop_h, crop_w, C)

cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
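random_crop_vectorized above crops every image in a batch at an independent random offset with one advanced-indexing gather instead of a Python loop. A self-contained sketch of the same indexing trick:

import torch

def random_crop_batch(images: torch.Tensor, crop_h: int, crop_w: int) -> torch.Tensor:
    """(B, C, H, W) -> (B, C, crop_h, crop_w), each sample cropped at its own offset."""
    B, C, H, W = images.shape
    top = torch.randint(0, H - crop_h + 1, (B,))
    left = torch.randint(0, W - crop_w + 1, (B,))
    rows = top.view(B, 1, 1) + torch.arange(crop_h).view(1, crop_h, 1)  # (B, crop_h, 1)
    cols = left.view(B, 1, 1) + torch.arange(crop_w).view(1, 1, crop_w)  # (B, 1, crop_w)
    hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
    cropped = hwcn[torch.arange(B).view(B, 1, 1), rows, cols, :]  # broadcasts to (B, crop_h, crop_w, C)
    return cropped.permute(0, 3, 1, 2)

print(random_crop_batch(torch.rand(4, 3, 64, 64), 32, 32).shape)  # torch.Size([4, 3, 32, 32])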
|
||||
@@ -198,12 +194,8 @@ class ReplayBuffer:
|
||||
"""
|
||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||
# a replay buffer than from a lerobot dataset.
|
||||
replay_buffer = cls(
|
||||
capacity=len(lerobot_dataset), device=device, state_keys=state_keys
|
||||
)
|
||||
list_transition = cls._lerobotdataset_to_transitions(
|
||||
dataset=lerobot_dataset, state_keys=state_keys
|
||||
)
|
||||
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
replay_buffer.add(
|
||||
@@ -248,9 +240,7 @@ class ReplayBuffer:
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError(
|
||||
"You must provide a list of keys in `state_keys` that define your 'state'."
|
||||
)
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
@@ -304,40 +294,36 @@ class ReplayBuffer:
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat(
|
||||
[t["state"][key] for t in list_of_transitions], dim=0
|
||||
).to(self.device)
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
|
||||
self.device
|
||||
)
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor(
|
||||
[t["reward"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat(
|
||||
[t["next_state"][key] for t in list_of_transitions], dim=0
|
||||
).to(self.device)
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(
|
||||
batch_next_state[key]
|
||||
)
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor(
|
||||
[t["done"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
batch_dones = torch.tensor(
|
||||
[t["done"] for t in list_of_transitions], dtype=torch.float32
|
||||
).to(self.device)
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
@@ -427,9 +413,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online traning, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
|
||||
if cfg.resume
|
||||
else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
@@ -438,9 +422,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
# TODO: Handle resume
|
||||
|
||||
num_learnable_params = sum(
|
||||
p.numel() for p in policy.parameters() if p.requires_grad
|
||||
)
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
@@ -481,16 +463,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

if interaction_step >= cfg.training.online_step_before_learning:
action = policy.select_action(batch=obs)
next_obs, reward, done, truncated, info = online_env.step(
action.cpu().numpy()
)
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
else:
action = online_env.action_space.sample()
next_obs, reward, done, truncated, info = online_env.step(action)
# HACK
action = torch.tensor(action, dtype=torch.float32).to(
device, non_blocking=True
)
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)

# HACK: For maniskill
# next_obs = preprocess_observation(next_obs)
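The loop above samples random actions until online_step_before_learning is reached, then switches to the policy; random samples come back as NumPy arrays, hence the tensor conversion flagged with # HACK. A gym-style sketch of the same warmup pattern (the environment, threshold, and step count are illustrative, and the policy branch is stubbed out):

import gymnasium as gym
import torch

env = gym.make("Pendulum-v1")
obs, info = env.reset(seed=0)
online_step_before_learning = 100  # illustrative threshold

for interaction_step in range(200):
    if interaction_step >= online_step_before_learning:
        # The real loop calls policy.select_action(batch=obs) here.
        action = env.action_space.sample()
    else:
        action = env.action_space.sample()
    next_obs, reward, done, truncated, info = env.step(action)
    # Sampled actions are NumPy arrays, hence the conversion in the diff:
    action = torch.tensor(action, dtype=torch.float32)
    obs = next_obs
    if done or truncated:
        obs, info = env.reset()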
@@ -500,20 +478,14 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# Because we are using a single environment
# we can safely assume that the episode is done
if done[0] or truncated[0]:
logging.info(
f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logger.log_dict(
{"Sum episode reward": sum_reward_episode}, interaction_step
)
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
sum_reward_episode = 0
# HACK: This is for maniskill
logging.info(
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
)
logger.log_dict(
{"Episode success": info["success"].float().item()}, interaction_step
)
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)

replay_buffer.add(
state=obs,
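The replay_buffer.add(...) call is truncated by the hunk boundary; judging from the Transition keys used in the batching code earlier, it stores a state/action/reward/next_state/done record. A self-contained stand-in (the class and exact keyword list are assumptions, not the repo's ReplayBuffer):

from collections import deque

class TinyReplayBuffer:
    # Minimal stand-in that stores dict transitions, like the repo's buffer appears to.
    def __init__(self, capacity: int = 10_000):
        self.memory = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.memory.append(
            {"state": state, "action": action, "reward": reward,
             "next_state": next_state, "done": done}
        )

buffer = TinyReplayBuffer()
buffer.add(state={"observation.state": [0.0]}, action=[0.1], reward=1.0,
           next_state={"observation.state": [0.1]}, done=False)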
@@ -587,9 +559,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

training_infos["loss_actor"] = loss_actor.item()

loss_temperature = policy.compute_loss_temperature(
observations=observations
)
loss_temperature = policy.compute_loss_temperature(observations=observations)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
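compute_loss_temperature plus the three optimizer calls above are the standard SAC temperature (entropy coefficient) update. A self-contained sketch of one common formulation; the log-alpha parametrization, target entropy value, and random log-probs are assumptions about the idea, not a copy of the policy's internals:

import torch

log_alpha = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.Adam([log_alpha], lr=3e-4)
target_entropy = -4.0  # commonly set to -dim(action_space)

log_probs = torch.randn(256)  # stand-in for log pi(a|s) of freshly sampled actions
# Alpha is pushed up when policy entropy is below target, down when above.
loss_temperature = -(log_alpha.exp() * (log_probs + target_entropy).detach()).mean()

optimizer.zero_grad()
loss_temperature.backward()
optimizer.step()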
@@ -611,9 +581,7 @@ def train_cli(cfg: dict):
)


def train_notebook(
out_dir=None, job_name=None, config_name="default", config_path="../configs"
):
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize

hydra.core.global_hydra.GlobalHydra.instance().clear()
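train_notebook clears Hydra's global state so initialize/compose can be re-run inside the same notebook kernel. A sketch of how those pieces fit together, assuming Hydra >= 1.2 (the config_path and config_name mirror the defaults above):

from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra

# Without this, a second initialize() in the same process raises an error.
GlobalHydra.instance().clear()
with initialize(config_path="../configs", version_base=None):
    cfg = compose(config_name="default")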
@@ -94,12 +94,8 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
c, h, w = chw_float32_torch.shape
assert c < h and c < w, (
f"expect channel first images, but instead {chw_float32_torch.shape}"
)
hwc_uint8_numpy = (
(chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
)
assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
return hwc_uint8_numpy
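For reference, the reformatted function reads as one piece like this runnable sketch; the logic is exactly the diff's, only the demo call at the bottom is added:

import numpy as np
import torch

def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    c, h, w = chw_float32_torch.shape
    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
    # Scale [0, 1] floats to [0, 255], reorder CHW -> HWC, and move to NumPy.
    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
    return hwc_uint8_numpy

frame = to_hwc_uint8_numpy(torch.rand(3, 96, 128))
print(frame.shape, frame.dtype)  # (96, 128, 3) uint8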
@@ -142,12 +142,8 @@ def run_server(
)
)

@app.route(
"/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
)
def show_episode(
dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
):
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
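show_episode binds dataset and episodes through default arguments because the handler is defined inside run_server: defaults are evaluated once at definition time, so the view keeps references to the enclosing objects without relying on closure lookup. A minimal Flask sketch of the same idiom (route and payload are illustrative):

from flask import Flask

def run_server(dataset):
    app = Flask(__name__)

    # The default argument captures `dataset` when the view is defined.
    @app.route("/<string:name>/episode_<int:episode_id>")
    def show_episode(name, episode_id, dataset=dataset):
        return {"name": name, "episode_id": episode_id, "num_items": len(dataset)}

    return app

app = run_server(dataset=["ep0", "ep1"])  # app.run() would serve it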
@@ -158,9 +154,7 @@ def run_server(
400,
)
dataset_version = (
str(dataset.meta._version)
if isinstance(dataset, LeRobotDataset)
else dataset.codebase_version
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@@ -168,9 +162,7 @@ def run_server(
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."

episode_data_csv_str, columns, ignored_columns = get_episode_data(
dataset, episode_id
)
episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
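The version gate above pulls the major version out of strings like "v1.6" with re.search and rejects anything below v2. The same check in isolation:

import re

dataset_version = "v1.6"  # illustrative value
match = re.search(r"v(\d+)\.", dataset_version)
if match:
    major_version = int(match.group(1))
    if major_version < 2:
        print("Make sure to convert your LeRobotDataset to v2 & above.")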
@@ -183,8 +175,7 @@ def run_server(
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key)
for key in dataset.meta.video_keys
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{
@@ -197,9 +188,7 @@ def run_server(
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [
key for key, ft in dataset.features.items() if ft["dtype"] == "video"
]
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
@@ -219,24 +208,16 @@ def run_server(
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [
json.loads(line) for line in response.text.splitlines() if line.strip()
]
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]

filtered_tasks_jsonl = [
row for row in tasks_jsonl if row["episode_index"] == episode_id
]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]

videos_info[0]["language_instruction"] = tasks

if episodes is None:
episodes = list(
range(
dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes
)
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)

return render_template(
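The hub fallback above downloads a tasks JSONL file, parses it line by line, and keeps only the rows for the requested episode. A standalone sketch of that parse; the inline text stands in for response.text after a successful requests.get, and the row contents are illustrative:

import json

# Stand-in for response.text after requests.get(...).raise_for_status()
response_text = (
    '{"episode_index": 0, "tasks": ["pick cube"]}\n'
    '{"episode_index": 1, "tasks": ["push cube"]}\n'
)

tasks_jsonl = [json.loads(line) for line in response_text.splitlines() if line.strip()]
episode_id = 1
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]  # -> ["push cube"]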
@@ -263,11 +244,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []

selected_columns = [
col
for col, ft in dataset.features.items()
if ft["dtype"] in ["float32", "int32"]
]
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
selected_columns.remove("timestamp")

ignored_columns = []
@@ -288,10 +265,7 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else dataset.features[column_name].shape[0]
)

if (
"names" in dataset.features[column_name]
and dataset.features[column_name]["names"]
):
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
@@ -314,12 +288,9 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
else:
repo_id = dataset.repo_id

url = (
f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size,
episode_index=episode_index,
)
df = pd.read_parquet(url)
data = df[selected_columns]  # Select specific columns
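pd.read_parquet accepts the composed huggingface.co resolve URL directly, so no manual download step is needed. A sketch with an illustrative repo id; the data_path template here is an assumption modeled on the chunked layout the diff formats into, not necessarily the exact one:

import pandas as pd

repo_id = "lerobot/pusht"  # illustrative
data_path = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"  # assumed template
episode_index, chunks_size = 0, 1000

url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + data_path.format(
    episode_chunk=int(episode_index) // chunks_size,
    episode_index=episode_index,
)
df = pd.read_parquet(url)  # requires network access and pyarrow (or fastparquet)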
@@ -352,9 +323,7 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]


def get_episode_language_instruction(
dataset: LeRobotDataset, ep_index: int
) -> list[str]:
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
@@ -365,9 +334,7 @@ def get_episode_language_instruction(
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
"', shape=(), dtype=string)"
)
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")

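The removeprefix/removesuffix chain strips the serialized tf.Tensor(...) wrapper that some OpenX strings carry, per the TODO above. In isolation (Python >= 3.9; the instruction text is illustrative):

raw = "tf.Tensor(b'pick up the blue block', shape=(), dtype=string)"
clean = raw.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
print(clean)  # pick up the blue block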
def get_dataset_info(repo_id: str) -> IterableNamespace:
@@ -403,9 +370,7 @@ def visualize_dataset_html(
if force_override:
shutil.rmtree(output_dir)
else:
logging.info(
f"Output directory already exists. Loading from it: '{output_dir}'"
)
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")

output_dir.mkdir(parents=True, exist_ok=True)

@@ -47,9 +47,7 @@ OUTPUT_DIR = Path("outputs/image_transforms")
to_pil = ToPILImage()


def save_all_transforms(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
@@ -62,9 +60,7 @@ def save_all_transforms(
print(f"    {output_dir_all}")


def save_each_transform(
cfg: ImageTransformsConfig, original_frame, output_dir, n_examples
):
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
if not cfg.enable:
logging.warning(
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
@@ -93,15 +89,9 @@ def save_each_transform(
tf_cfg_kwgs_max[key] = [max_, max_]
tf_cfg_kwgs_avg[key] = [avg, avg]

tf_min = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min})
)
tf_max = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max})
)
tf_avg = make_transform_from_config(
replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg})
)
tf_min = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_min}))
tf_max = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_max}))
tf_avg = make_transform_from_config(replace(tf_cfg, **{"kwargs": tf_cfg_kwgs_avg}))

tf_frame_min = tf_min(original_frame)
tf_frame_max = tf_max(original_frame)
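make_transform_from_config(replace(tf_cfg, ...)) builds three variants of one transform config by swapping only its kwargs; dataclasses.replace returns a modified copy and leaves the original untouched. A self-contained sketch of that min/max/avg sweep, using a stand-in config class rather than the repo's ImageTransformsConfig:

from dataclasses import dataclass, field, replace

@dataclass
class TransformConfig:  # stand-in for the repo's per-transform config
    type: str = "ColorJitter"
    kwargs: dict = field(default_factory=lambda: {"brightness": [0.8, 1.2]})

tf_cfg = TransformConfig()
min_, max_ = tf_cfg.kwargs["brightness"]
avg = (min_ + max_) / 2

# Pin each variant's range to a single value, as the diff does per kwarg key.
tf_cfg_min = replace(tf_cfg, **{"kwargs": {"brightness": [min_, min_]}})
tf_cfg_max = replace(tf_cfg, **{"kwargs": {"brightness": [max_, max_]}})
tf_cfg_avg = replace(tf_cfg, **{"kwargs": {"brightness": [avg, avg]}})
print(tf_cfg_min.kwargs, tf_cfg.kwargs)  # original config is unchanged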
@@ -115,9 +105,7 @@ def save_each_transform(


@draccus.wrap()
def visualize_image_transforms(
cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5
):
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,