[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
commit 85fe8a3f4e (parent bb69cb3c8c)
committed by Michel Aractingi
@@ -108,20 +108,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
            break

    if motor_index == -1:
-        raise ValueError("No motors detected. Please ensure you have one motor connected.")
+        raise ValueError(
+            "No motors detected. Please ensure you have one motor connected."
+        )

    print(f"Motor index found at: {motor_index}")

    if brand == "feetech":
        # Allows ID and BAUDRATE to be written in memory
-        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
+        motor_bus.write_with_motor_ids(
+            motor_bus.motor_models, motor_index, "Lock", 0
+        )

    if baudrate != baudrate_des:
        print(f"Setting its baudrate to {baudrate_des}")
        baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des)

        # The write can fail, so we allow retries
-        motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx)
+        motor_bus.write_with_motor_ids(
+            motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
+        )
        time.sleep(0.5)
        motor_bus.set_bus_baudrate(baudrate_des)
        present_baudrate_idx = motor_bus.read_with_motor_ids(

@@ -136,7 +142,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
    motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
    motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)

-    present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2)
+    present_idx = motor_bus.read_with_motor_ids(
+        motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
+    )
    if present_idx != motor_idx_des:
        raise OSError("Failed to write index.")

@@ -164,12 +172,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-    parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)")
-    parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)")
-    parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)")
-    parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)")
    parser.add_argument(
-        "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)"
+        "--port",
+        type=str,
+        required=True,
+        help="Motors bus port (e.g. dynamixel,feetech)",
+    )
+    parser.add_argument(
+        "--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
+    )
+    parser.add_argument(
+        "--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
+    )
+    parser.add_argument(
+        "--ID",
+        type=int,
+        required=True,
+        help="Desired ID of the current motor (e.g. 1,2,3)",
+    )
+    parser.add_argument(
+        "--baudrate",
+        type=int,
+        default=1000000,
+        help="Desired baudrate for the motor (default: 1000000)",
    )
    args = parser.parse_args()

@@ -149,7 +149,11 @@ def init_sim_calibration(robot, cfg):
    axis_directions = np.array(cfg.get("axis_directions", [1]))
    offsets = np.array(cfg.get("offsets", [0])) * np.pi

-    return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
+    return {
+        "start_pos": start_pos,
+        "axis_directions": axis_directions,
+        "offsets": offsets,
+    }


def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):

@@ -170,7 +174,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
        leader_pos = robot.leader_arms.main.read("Present_Position")
        action = process_action_fn(leader_pos)
        env.step(np.expand_dims(action, 0))
-        if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
+        if (
+            teleop_time_s is not None
+            and time.perf_counter() - start_teleop_t > teleop_time_s
+        ):
            print("Teleoperation processes finished.")
            break

@@ -202,19 +209,27 @@ def record(
    # Load pretrained policy

    extra_features = (
-        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
+        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}}
+        if assign_rewards
+        else None
    )

    policy = None
    if pretrained_policy_name_or_path is not None:
-        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
+        policy, policy_fps, device, use_amp = init_policy(
+            pretrained_policy_name_or_path, policy_overrides
+        )

        if fps is None:
            fps = policy_fps
-            logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
+            logging.warning(
+                f"No fps provided, so using the fps from policy config ({policy_fps})."
+            )

    if policy is None and process_action_from_leader is None:
-        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
+        raise ValueError(
+            "Either policy or process_action_fn has to be set to enable control in sim."
+        )

    # initialize listener before sim env
    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)

@@ -256,7 +271,11 @@ def record(
            "shape": env.observation_space[obs_key].shape,
        }

-    features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
+    features["action"] = {
+        "dtype": "float32",
+        "shape": env.action_space.shape,
+        "names": None,
+    }
    features = {**features, **extra_features}

    # Create empty dataset or load existing saved episodes

@@ -357,7 +376,9 @@ def record(
        if events["stop_recording"] or recorded_episodes >= num_episodes:
            break
        else:
-            logging.info("Waiting for a few seconds before starting next episode recording...")
+            logging.info(
+                "Waiting for a few seconds before starting next episode recording..."
+            )
            busy_wait(3)

    log_say("Stop recording", play_sounds, blocking=True)

@@ -375,7 +396,12 @@ def record(


def replay(
-    env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
+    env,
+    root: Path,
+    repo_id: str,
+    episode: int,
+    fps: int | None = None,
+    local_files_only: bool = True,
):
    env = env()

@@ -422,7 +448,10 @@ if __name__ == "__main__":

    parser_record = subparsers.add_parser("record", parents=[base_parser])
    parser_record.add_argument(
-        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
+        "--fps",
+        type=none_or_int,
+        default=None,
+        help="Frames per second (set to None to disable)",
    )
    parser_record.add_argument(
        "--root",

@@ -448,7 +477,9 @@ if __name__ == "__main__":
        required=True,
        help="A description of the task performed during recording that can be used as a language instruction.",
    )
-    parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
+    parser_record.add_argument(
+        "--num-episodes", type=int, default=50, help="Number of episodes to record."
+    )
    parser_record.add_argument(
        "--run-compute-stats",
        type=int,

@@ -509,7 +540,10 @@ if __name__ == "__main__":

    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
-        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
+        "--fps",
+        type=none_or_int,
+        default=None,
+        help="Frames per second (set to None to disable)",
    )
    parser_replay.add_argument(
        "--root",

@@ -523,7 +557,9 @@ if __name__ == "__main__":
        default="lerobot/test",
        help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
    )
-    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
+    parser_replay.add_argument(
+        "--episode", type=int, default=0, help="Index of the episodes to replay."
+    )

    args = parser.parse_args()

@@ -59,7 +59,11 @@ np_version = np.__version__ if HAS_NP else "N/A"

torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
-cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
+cuda_version = (
+    torch._C._cuda_getCompiledVersion()
+    if HAS_TORCH and torch.version.cuda is not None
+    else "N/A"
+)


# TODO(aliberts): refactor into an actual command `lerobot env`

@@ -77,7 +81,9 @@ def display_sys_info() -> dict:
        "Using GPU in script?": "<fill in>",
        # "Using distributed or parallel set-up in script?": "<fill in>",
    }
-    print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
+    print(
+        "\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n"
+    )
    print(format_dict(info))
    return info

@@ -174,7 +174,10 @@ def rollout(
        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available if none of the envs finished.
        if "final_info" in info:
-            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
+            successes = [
+                info["is_success"] if info is not None else False
+                for info in info["final_info"]
+            ]
        else:
            successes = [False] * env.num_envs

@@ -188,9 +191,13 @@ def rollout(

        step += 1
        running_success_rate = (
-            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
+            einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any")
+            .numpy()
+            .mean()
        )
-        progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
+        progbar.set_postfix(
+            {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}
+        )
        progbar.update()

    # Track the final observation.

@@ -208,7 +215,9 @@ def rollout(
    if return_observations:
        stacked_observations = {}
        for key in all_observations[0]:
-            stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
+            stacked_observations[key] = torch.stack(
+                [obs[key] for obs in all_observations], dim=1
+            )
        ret["observation"] = stacked_observations

    if hasattr(policy, "use_original_modules"):

@@ -270,7 +279,9 @@ def eval_policy(
            return
        n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
        if isinstance(env, gym.vector.SyncVectorEnv):
-            ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)]))  # noqa: B023
+            ep_frames.append(
+                np.stack([env.envs[i].render() for i in range(n_to_render_now)])
+            )  # noqa: B023
        elif isinstance(env, gym.vector.AsyncVectorEnv):
            # Here we must render all frames and discard any we don't need.
            ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))

@@ -282,7 +293,9 @@ def eval_policy(
    episode_data: dict | None = None

    # we don't want a progress bar when we use slurm, since it clutters the logs
-    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
+    progbar = trange(
+        n_batches, desc="Stepping through eval batches", disable=inside_slurm()
+    )
    for batch_ix in progbar:
        # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
        # step.

@@ -293,7 +306,8 @@ def eval_policy(
            seeds = None
        else:
            seeds = range(
-                start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs)
+                start_seed + (batch_ix * env.num_envs),
+                start_seed + ((batch_ix + 1) * env.num_envs),
            )
        rollout_data = rollout(
            env,

@@ -311,13 +325,22 @@ def eval_policy(

        # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done
        # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step.
-        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
+        mask = (
+            torch.arange(n_steps)
+            <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)
+        ).int()
        # Extend metrics.
-        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
+        batch_sum_rewards = einops.reduce(
+            (rollout_data["reward"] * mask), "b n -> b", "sum"
+        )
        sum_rewards.extend(batch_sum_rewards.tolist())
-        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
+        batch_max_rewards = einops.reduce(
+            (rollout_data["reward"] * mask), "b n -> b", "max"
+        )
        max_rewards.extend(batch_max_rewards.tolist())
-        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")
+        batch_successes = einops.reduce(
+            (rollout_data["success"] * mask), "b n -> b", "any"
+        )
        all_successes.extend(batch_successes.tolist())
        if seeds:
            all_seeds.extend(seeds)

@@ -330,17 +353,27 @@ def eval_policy(
                rollout_data,
                done_indices,
                start_episode_index=batch_ix * env.num_envs,
-                start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
+                start_data_index=(
+                    0
+                    if episode_data is None
+                    else (episode_data["index"][-1].item() + 1)
+                ),
                fps=env.unwrapped.metadata["render_fps"],
            )
            if episode_data is None:
                episode_data = this_episode_data
            else:
                # Some sanity checks to make sure we are correctly compiling the data.
-                assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0]
+                assert (
+                    episode_data["episode_index"][-1] + 1
+                    == this_episode_data["episode_index"][0]
+                )
                assert episode_data["index"][-1] + 1 == this_episode_data["index"][0]
                # Concatenate the episode data.
-                episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data}
+                episode_data = {
+                    k: torch.cat([episode_data[k], this_episode_data[k]])
+                    for k in episode_data
+                }

        # Maybe render video for visualization.
        if max_episodes_rendered > 0 and len(ep_frames) > 0:

@@ -358,7 +391,9 @@ def eval_policy(
                    target=write_video,
                    args=(
                        str(video_path),
-                        stacked_frames[: done_index + 1],  # + 1 to capture the last observation
+                        stacked_frames[
+                            : done_index + 1
+                        ],  # + 1 to capture the last observation
                        env.unwrapped.metadata["render_fps"],
                    ),
                )

@@ -367,7 +402,9 @@ def eval_policy(
                n_episodes_rendered += 1

        progbar.set_postfix(
-            {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"}
+            {
+                "running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"
+            }
        )

    # Wait till all video rendering threads are done.

@@ -413,7 +450,11 @@ def eval_policy(


def _compile_episode_data(
-    rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float
+    rollout_data: dict,
+    done_indices: Tensor,
+    start_episode_index: int,
+    start_data_index: int,
+    fps: float,
) -> dict:
    """Convenience function for `eval_policy(return_episode_data=True)`

@@ -431,12 +472,16 @@ def _compile_episode_data(
        # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet.
        ep_dict = {
            "action": rollout_data["action"][ep_ix, : num_frames - 1],
-            "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)),
+            "episode_index": torch.tensor(
+                [start_episode_index + ep_ix] * (num_frames - 1)
+            ),
            "frame_index": torch.arange(0, num_frames - 1, 1),
            "timestamp": torch.arange(0, num_frames - 1, 1) / fps,
            "next.done": rollout_data["done"][ep_ix, : num_frames - 1],
            "next.success": rollout_data["success"][ep_ix, : num_frames - 1],
-            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32),
+            "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(
+                torch.float32
+            ),
        }

        # For the last observation frame, all other keys will just be copy padded.

@@ -452,7 +497,9 @@ def _compile_episode_data(
    for key in ep_dicts[0]:
        data_dict[key] = torch.cat([x[key] for x in ep_dicts])

-    data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1)
+    data_dict["index"] = torch.arange(
+        start_data_index, start_data_index + total_frames, 1
+    )

    return data_dict

@@ -46,7 +46,11 @@ import torch
from tqdm import trange

from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
+from lerobot.common.robot_devices.control_utils import (
+    busy_wait,
+    is_headless,
+    reset_follower_position,
+)
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
    init_hydra_config,

@@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path):
        return

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
-    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
-    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
+    from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+        ClassifierConfig,
+    )
+    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
+        Classifier,
+    )

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
-    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
+    classifier_config.num_cameras = len(
+        cfg.training.image_keys
+    )  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to("mps")

@@ -151,11 +161,17 @@ def rollout(
        images = []
        for key in image_keys:
            if display_cameras:
-                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+                cv2.imshow(
+                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
+                )
                cv2.waitKey(1)
            images.append(observation[key].to("mps"))

-        reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
+        reward = (
+            reward_classifier.predict_reward(images)
+            if reward_classifier is not None
+            else 0.0
+        )
        all_rewards.append(reward)

        # print("REWARD : ", reward)

@@ -219,11 +235,19 @@ def eval_policy(

    start_eval = time.perf_counter()
    progbar = trange(n_episodes, desc="Evaluating policy on real robot")
-    reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)
+    reward_classifier = get_classifier(
+        reward_classifier_pretrained_path, reward_classifier_config_file
+    )

    for _ in progbar:
        rollout_data = rollout(
-            robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
+            robot,
+            policy,
+            reward_classifier,
+            fps,
+            control_time_s,
+            use_amp,
+            display_cameras,
        )

        rollouts.append(rollout_data)

@@ -289,7 +313,9 @@ def init_keyboard_listener():
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
-                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
+                print(
+                    "Left arrow key pressed. Exiting loop and rerecord the last episode..."
+                )
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.space:

@@ -301,7 +327,10 @@ def init_keyboard_listener():
                        "Place the leader in similar pose to the follower and press space again."
                    )
                    events["pause_policy"] = True
-                    log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
+                    log_say(
+                        "Human intervention stage. Get ready to take over.",
+                        play_sounds=True,
+                    )
                else:
                    events["human_intervention_step"] = True
                    print("Space key pressed. Human intervention starting.")

@@ -351,7 +380,9 @@ if __name__ == "__main__":
            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
        ),
    )
-    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
+    parser.add_argument(
+        "--revision", help="Optionally provide the Hugging Face Hub revision ID."
+    )
    parser.add_argument(
        "--out-dir",
        help=(

@@ -360,7 +391,8 @@ if __name__ == "__main__":
        ),
    )
    parser.add_argument(
-        "--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
+        "--display-cameras",
+        help=("Whether to display the camera feed while the rollout is happening"),
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",

@@ -45,9 +45,13 @@ def find_port():
        print(f"The port of this MotorsBus is '{port}'")
        print("Reconnect the USB cable.")
    elif len(ports_diff) == 0:
-        raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
+        raise OSError(
+            f"Could not detect the port. No difference was found ({ports_diff})."
+        )
    else:
-        raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
+        raise OSError(
+            f"Could not detect the port. More than one port was found ({ports_diff})."
+        )


if __name__ == "__main__":

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import random
from typing import Any, Callable, Optional, Sequence, TypedDict

import io

@@ -737,7 +736,6 @@ def concatenate_batch_transitions(


if __name__ == "__main__":
    import numpy as np
    from tempfile import TemporaryDirectory

    # ===== Test 1: Create and use a synthetic ReplayBuffer =====

@@ -1139,7 +1137,7 @@ if __name__ == "__main__":

    savings_percent = (std_mem - opt_mem) / std_mem * 100

-    print(f"\nMemory optimization result:")
+    print("\nMemory optimization result:")
    print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB")
    print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB")
    print(f"- Memory savings for state tensors: {savings_percent:.1f}%")

@@ -225,7 +225,9 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(


if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
+    parser = argparse.ArgumentParser(
+        description="Crop rectangular ROIs from a LeRobot dataset."
+    )
    parser.add_argument(
        "--repo-id",
        type=str,

@@ -247,7 +249,9 @@ if __name__ == "__main__":
    args = parser.parse_args()

    local_files_only = args.root is not None
-    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)
+    dataset = LeRobotDataset(
+        repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
+    )

    images = get_image_from_lerobot_dataset(dataset)
    images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}

@@ -256,7 +260,7 @@ if __name__ == "__main__":
    if args.crop_params_path is None:
        rois = select_square_roi_for_images(images)
    else:
-        with open(args.crop_params_path, "r") as f:
+        with open(args.crop_params_path) as f:
            rois = json.load(f)

    # rois = {

@@ -31,7 +31,9 @@ def find_joint_bounds(
        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
-                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+                cv2.imshow(
+                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
+                )
                cv2.waitKey(1)

        timestamp = time.perf_counter() - start_episode_t

@@ -57,7 +59,12 @@ if __name__ == "__main__":
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides)",
    )
-    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
+    parser.add_argument(
+        "--control-time-s",
+        type=float,
+        default=20,
+        help="Maximum episode length in seconds",
+    )
    args = parser.parse_args()
    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)

@@ -146,7 +146,7 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:


def initialize_replay_buffer(
-    cfg: DictConfig, logger: Logger, device: str, storage_device:str
+    cfg: DictConfig, logger: Logger, device: str, storage_device: str
) -> ReplayBuffer:
    if not cfg.resume:
        return ReplayBuffer(

@@ -10,7 +10,9 @@ from typing import Any
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv


-def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
+def preprocess_maniskill_observation(
+    observations: dict[str, np.ndarray],
+) -> dict[str, torch.Tensor]:
    """Convert environment observation to LeRobot format observation.
    Args:
        observation: Dictionary of observation batches from a Gym vector environment.

@@ -62,7 +64,9 @@ class ManiSkillCompat(gym.Wrapper):
        new_action_space_shape = env.action_space.shape[-1]
        new_low = np.squeeze(env.action_space.low, axis=0)
        new_high = np.squeeze(env.action_space.high, axis=0)
-        self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))
+        self.action_space = gym.spaces.Box(
+            low=new_low, high=new_high, shape=(new_action_space_shape,)
+        )

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None

@@ -81,7 +85,9 @@ class ManiSkillCompat(gym.Wrapper):
class ManiSkillActionWrapper(gym.ActionWrapper):
    def __init__(self, env):
        super().__init__(env)
-        self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))
+        self.action_space = gym.spaces.Tuple(
+            spaces=(env.action_space, gym.spaces.Discrete(2))
+        )

    def action(self, action):
        action, telop = action

@@ -95,7 +101,9 @@ class ManiSkillMultiplyActionWrapper(gym.Wrapper):
        action_space_agent: gym.spaces.Box = env.action_space[0]
        action_space_agent.low = action_space_agent.low * multiply_factor
        action_space_agent.high = action_space_agent.high * multiply_factor
-        self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))
+        self.action_space = gym.spaces.Tuple(
+            spaces=(action_space_agent, gym.spaces.Discrete(2))
+        )

    def step(self, action):
        if isinstance(action, tuple):

@@ -137,7 +145,9 @@ def make_maniskill(

    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
-    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
+    env._max_episode_steps = env.max_episode_steps = (
+        50  # gym_utils.find_max_episode_steps_value(env)
+    )
    env.unwrapped.metadata["render_fps"] = 20
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)

@@ -149,10 +159,11 @@ def make_maniskill(

if __name__ == "__main__":
    import argparse
    import hydra
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
+    parser.add_argument(
+        "--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
+    )
    args = parser.parse_args()

    # Initialize config

@@ -73,7 +73,9 @@ def make_optimizer_and_scheduler(cfg, policy):
            },
        ]
        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+            optimizer_params_dicts,
+            lr=cfg.training.lr,
+            weight_decay=cfg.training.weight_decay,
        )
        lr_scheduler = None
    elif cfg.policy.name == "diffusion":

@@ -100,14 +102,23 @@ def make_optimizer_and_scheduler(cfg, policy):
        optimizer = torch.optim.Adam(
            [
                {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
-                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
-                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+                {
+                    "params": policy.critic_ensemble.parameters(),
+                    "lr": policy.config.critic_lr,
+                },
+                {
+                    "params": policy.temperature.parameters(),
+                    "lr": policy.config.temperature_lr,
+                },
            ]
        )
        lr_scheduler = None

    elif cfg.policy.name == "vqbet":
-        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
+        from lerobot.common.policies.vqbet.modeling_vqbet import (
+            VQBeTOptimizer,
+            VQBeTScheduler,
+        )

        optimizer = VQBeTOptimizer(policy, cfg)
        lr_scheduler = VQBeTScheduler(optimizer, cfg)

@@ -214,7 +225,9 @@ def train(cfg: TrainPipelineConfig):
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_learnable_params = sum(
+        p.numel() for p in policy.parameters() if p.requires_grad
+    )
    num_total_params = sum(p.numel() for p in policy.parameters())

    logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")

@@ -14,7 +14,6 @@
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat

import hydra

@@ -28,14 +27,16 @@ from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
-from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
+from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler
from tqdm import tqdm

from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
-from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
+from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+    ClassifierConfig,
+)
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
    format_big_number,

@@ -50,7 +51,11 @@ def get_model(cfg, logger):  # noqa I001
    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    model = Classifier(classifier_config)
    if cfg.resume:
-        model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
+        model.load_state_dict(
+            Classifier.from_pretrained(
+                str(logger.last_pretrained_model_dir)
+            ).state_dict()
+        )
    return model


@@ -62,7 +67,9 @@ def create_balanced_sampler(dataset, cfg):
    class_weights = 1.0 / counts.float()
    sample_weights = class_weights[labels]

-    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
+    return WeightedRandomSampler(
+        weights=sample_weights, num_samples=len(sample_weights), replacement=True
+    )


def support_amp(device: torch.device, cfg: DictConfig) -> bool:

@@ -71,7 +78,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool:
    return cfg.training.use_amp and device.type in ("cuda", "cpu")


-def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
+def train_epoch(
+    model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg
+):
    # Single epoch training loop with AMP support and progress tracking
    model.train()
    correct = 0

@@ -85,7 +94,11 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device,
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
-        with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
+        with (
+            torch.autocast(device_type=device.type)
+            if support_amp(device, cfg)
+            else nullcontext()
+        ):
            outputs = model(images)
            loss = criterion(outputs.logits, labels)

@@ -130,7 +143,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):

    with (
        torch.no_grad(),
-        torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
+        torch.autocast(device_type=device.type)
+        if support_amp(device, cfg)
+        else nullcontext(),
    ):
        for batch in tqdm(val_loader, desc="Validation"):
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]

@@ -143,7 +158,9 @@ def validate(model, val_loader, criterion, device, logger, cfg):
                ):
                    outputs = model(images)
                inference_times.append(
-                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+                    next(
+                        x for x in prof.key_averages() if x.key == "model_inference"
+                    ).cpu_time
                )
            else:
                outputs = model(images)

@@ -161,16 +178,24 @@ def validate(model, val_loader, criterion, device, logger, cfg):

            # Log sample predictions for visualization
            if len(samples) < cfg.eval.num_samples_to_log:
-                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
+                for i in range(
+                    min(cfg.eval.num_samples_to_log - len(samples), len(images))
+                ):
                    if model.config.num_classes == 2:
                        confidence = round(outputs.probabilities[i].item(), 3)
                    else:
-                        confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
+                        confidence = [
+                            round(prob, 3) for prob in outputs.probabilities[i].tolist()
+                        ]
                    samples.append(
                        {
                            **{
-                                f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
-                                for img_idx, img_key in enumerate(cfg.training.image_keys)
+                                f"image_{img_key}": wandb.Image(
+                                    images[img_idx][i].cpu()
+                                )
+                                for img_idx, img_key in enumerate(
+                                    cfg.training.image_keys
+                                )
                            },
                            "true_label": labels[i].item(),
                            "predicted": predictions[i].item(),

@@ -238,15 +263,24 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
        elif device.type == "mps":
            torch.mps.synchronize()

-        with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
+        with (
+            profiler.profile(record_shapes=True) as prof,
+            profiler.record_function("model_inference"),
+        ):
            _ = model(x)

        inference_times.append(
-            next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            next(
+                x for x in prof.key_averages() if x.key == "model_inference"
+            ).cpu_time
        )

    inference_times = np.array(inference_times)
-    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
+    avg, median, std = (
+        inference_times.mean(),
+        np.median(inference_times),
+        inference_times.std(),
+    )
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
    )

@@ -264,7 +298,11 @@ def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    return avg, median, std


-@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
+@hydra.main(
+    version_base="1.2",
+    config_path="../configs/policy",
+    config_name="hilserl_classifier",
+)
def train(cfg: DictConfig) -> None:
    # Main training pipeline with support for resuming training
    logging.info(OmegaConf.to_yaml(cfg))

@@ -278,7 +316,9 @@ def train(cfg: DictConfig) -> None:

    # Setup dataset and dataloaders
    dataset = LeRobotDataset(
-        cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
+        cfg.dataset_repo_id,
+        root=cfg.dataset_root,
+        local_files_only=cfg.local_files_only,
    )
    logging.info(f"Dataset size: {len(dataset)}")

@@ -314,7 +354,9 @@ def train(cfg: DictConfig) -> None:
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
-        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        checkpoint_cfg_path = str(
+            Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
+        )
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",

@@ -327,7 +369,9 @@ def train(cfg: DictConfig) -> None:
        # Check for differences between the checkpoint configuration and provided configuration.
        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
        resolve_delta_timestamps(cfg)
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        diff = DeepDiff(
+            OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
+        )
        # Ignore the `resume` and parameters.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]

@@ -346,7 +390,11 @@ def train(cfg: DictConfig) -> None:

    optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
    # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
-    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
+    criterion = (
+        nn.BCEWithLogitsLoss()
+        if model.config.num_classes == 2
+        else nn.CrossEntropyLoss()
+    )
    grad_scaler = GradScaler(enabled=cfg.training.use_amp)

    # Log model parameters

@@ -362,7 +410,17 @@ def train(cfg: DictConfig) -> None:
    for epoch in range(cfg.training.num_epochs):
        logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")

-        train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
+        train_epoch(
+            model,
+            train_loader,
+            criterion,
+            optimizer,
+            grad_scaler,
+            device,
+            logger,
+            step,
+            cfg,
+        )

        # Periodic validation
        if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:

@@ -22,7 +22,6 @@ from typing import Callable, Optional, Sequence, TypedDict
import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm

@@ -30,20 +29,17 @@ from tqdm import tqdm

# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.common.envs.factory import make_env, make_maniskill_env
-from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
+from lerobot.common.envs.factory import make_maniskill_env
+from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy


def make_optimizers_and_scheduler(cfg, policy):

@@ -56,7 +52,9 @@ def make_optimizers_and_scheduler(cfg, policy):
        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
    )
    # We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
-    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha], lr=policy.config.critic_lr
+    )
    lr_scheduler = None
    optimizers = {
        "actor": optimizer_actor,

@@ -108,7 +106,9 @@ def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Te
    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
-    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
+    cropped_hwcn = images_hwcn[
+        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
+    ]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)

@@ -198,8 +198,12 @@ class ReplayBuffer:
        """
        # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
        # a replay buffer than from a lerobot dataset.
-        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
-        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        replay_buffer = cls(
+            capacity=len(lerobot_dataset), device=device, state_keys=state_keys
+        )
+        list_transition = cls._lerobotdataset_to_transitions(
+            dataset=lerobot_dataset, state_keys=state_keys
+        )
        # Fill the replay buffer with the lerobot dataset transitions
        for data in list_transition:
            replay_buffer.add(

@@ -244,7 +248,9 @@ class ReplayBuffer:

        # If not provided, you can either raise an error or define a default:
        if state_keys is None:
-            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
+            raise ValueError(
+                "You must provide a list of keys in `state_keys` that define your 'state'."
+            )

        transitions: list[Transition] = []
        num_frames = len(dataset)

@@ -298,36 +304,40 @@ class ReplayBuffer:
        # -- Build batched states --
        batch_state = {}
        for key in self.state_keys:
-            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_state[key] = torch.cat(
+                [t["state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
            if key.startswith("observation.image") and self.use_drq:
                batch_state[key] = self.image_augmentation_function(batch_state[key])

        # -- Build batched actions --
-        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
-
-        # -- Build batched rewards --
-        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
+        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
+            self.device
+        )
+
+        # -- Build batched rewards --
+        batch_rewards = torch.tensor(
+            [t["reward"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)

        # -- Build batched next states --
        batch_next_state = {}
        for key in self.state_keys:
-            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
-                self.device
-            )
+            batch_next_state[key] = torch.cat(
+                [t["next_state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
            if key.startswith("observation.image") and self.use_drq:
-                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
+                batch_next_state[key] = self.image_augmentation_function(
+                    batch_next_state[key]
+                )

        # -- Build batched dones --
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
-        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
-            self.device
-        )
+        batch_dones = torch.tensor(
+            [t["done"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+        batch_dones = torch.tensor(
+            [t["done"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)

        # Return a BatchTransition typed dict
        return BatchTransition(

@@ -344,7 +354,13 @@ def concatenate_batch_transitions(
) -> BatchTransition:
    """NOTE: Be careful, it changes the left_batch_transitions in place."""
    left_batch_transitions["state"] = {
-        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
+        key: torch.cat(
+            [
+                left_batch_transitions["state"][key],
+                right_batch_transition["state"][key],
+            ],
+            dim=0,
+        )
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(

@@ -355,7 +371,11 @@ def concatenate_batch_transitions(
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
-            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
+            [
+                left_batch_transitions["next_state"][key],
+                right_batch_transition["next_state"][key],
+            ],
+            dim=0,
        )
        for key in left_batch_transitions["next_state"]
    }

@@ -407,7 +427,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: But if we do online training, we do not need dataset_stats
        dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
        device=device,
    )
    assert isinstance(policy, nn.Module)

@@ -416,7 +438,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

    # TODO: Handle resume

-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_learnable_params = sum(
+        p.numel() for p in policy.parameters() if p.requires_grad
+    )
    num_total_params = sum(p.numel() for p in policy.parameters())

    log_output_dir(out_dir)

@@ -433,7 +457,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}

    replay_buffer = ReplayBuffer(
-        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
+        capacity=cfg.training.online_buffer_capacity,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
    )

    batch_size = cfg.training.batch_size

@@ -455,12 +481,16 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

        if interaction_step >= cfg.training.online_step_before_learning:
            action = policy.select_action(batch=obs)
-            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
+            next_obs, reward, done, truncated, info = online_env.step(
+                action.cpu().numpy()
+            )
        else:
            action = online_env.action_space.sample()
            next_obs, reward, done, truncated, info = online_env.step(action)
            # HACK
-            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
+            action = torch.tensor(action, dtype=torch.float32).to(
+                device, non_blocking=True
+            )

        # HACK: For maniskill
        # next_obs = preprocess_observation(next_obs)

@@ -470,14 +500,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
        # Because we are using a single environment
        # we can safely assume that the episode is done
        if done[0] or truncated[0]:
-            logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
-            logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
+            logging.info(
+                f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
+            )
+            logger.log_dict(
+                {"Sum episode reward": sum_reward_episode}, interaction_step
+            )
            sum_reward_episode = 0
            # HACK: This is for maniskill
            logging.info(
                f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
            )
-            logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
+            logger.log_dict(
+                {"Episode success": info["success"].float().item()}, interaction_step
+            )

        replay_buffer.add(
            state=obs,

@@ -551,7 +587,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

        training_infos["loss_actor"] = loss_actor.item()

-        loss_temperature = policy.compute_loss_temperature(observations=observations)
+        loss_temperature = policy.compute_loss_temperature(
+            observations=observations
+        )
        optimizers["temperature"].zero_grad()
        loss_temperature.backward()
        optimizers["temperature"].step()

@@ -573,7 +611,9 @@ def train_cli(cfg: dict):
    )


-def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
+def train_notebook(
+    out_dir=None, job_name=None, config_name="default", config_path="../configs"
+):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()

@@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
    assert chw_float32_torch.dtype == torch.float32
    assert chw_float32_torch.ndim == 3
    c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
-    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    assert (
+        c < h and c < w
+    ), f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (
+        (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    )
    return hwc_uint8_numpy

@@ -81,7 +81,11 @@ def run_server(
    static_folder: Path,
    template_folder: Path,
):
-    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
+    app = Flask(
+        __name__,
+        static_folder=static_folder.resolve(),
+        template_folder=template_folder.resolve(),
+    )
    app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

    @app.route("/")

@@ -138,8 +142,12 @@ def run_server(
            )
        )

-    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+    @app.route(
+        "/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
+    )
+    def show_episode(
+        dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
+    ):
        repo_id = f"{dataset_namespace}/{dataset_name}"
        try:
            if dataset is None:

@@ -150,7 +158,9 @@ def run_server(
                400,
            )
        dataset_version = (
-            str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
+            str(dataset.meta._version)
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.codebase_version
        )
        match = re.search(r"v(\d+)\.", dataset_version)
        if match:

@@ -158,7 +168,9 @@ def run_server(
            if major_version < 2:
                return "Make sure to convert your LeRobotDataset to v2 & above."

-        episode_data_csv_str, columns, ignored_columns = get_episode_data(dataset, episode_id)
+        episode_data_csv_str, columns, ignored_columns = get_episode_data(
+            dataset, episode_id
+        )
        dataset_info = {
            "repo_id": f"{dataset_namespace}/{dataset_name}",
            "num_samples": dataset.num_frames

@@ -171,18 +183,23 @@ def run_server(
        }
        if isinstance(dataset, LeRobotDataset):
            video_paths = [
-                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
+                dataset.meta.get_video_file_path(episode_id, key)
+                for key in dataset.meta.video_keys
            ]
            videos_info = [
                {
-                    "url": url_for("static", filename=str(video_path).replace("\\", "/")),
+                    "url": url_for(
+                        "static", filename=str(video_path).replace("\\", "/")
+                    ),
                    "filename": video_path.parent.name,
                }
                for video_path in video_paths
            ]
            tasks = dataset.meta.episodes[episode_id]["tasks"]
        else:
-            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
+            video_keys = [
+                key for key, ft in dataset.features.items() if ft["dtype"] == "video"
+            ]
            videos_info = [
                {
                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"

@@ -197,20 +214,29 @@ def run_server(
            ]

            response = requests.get(
-                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl", timeout=5
+                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl",
+                timeout=5,
            )
            response.raise_for_status()
            # Split into lines and parse each line as JSON
-            tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
+            tasks_jsonl = [
+                json.loads(line) for line in response.text.splitlines() if line.strip()
+            ]

-            filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
+            filtered_tasks_jsonl = [
+                row for row in tasks_jsonl if row["episode_index"] == episode_id
+            ]
            tasks = filtered_tasks_jsonl[0]["tasks"]

            videos_info[0]["language_instruction"] = tasks

        if episodes is None:
            episodes = list(
-                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+                range(
+                    dataset.num_episodes
+                    if isinstance(dataset, LeRobotDataset)
+                    else dataset.total_episodes
+                )
            )

        return render_template(

@@ -237,7 +263,11 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
    This file will be loaded by Dygraph javascript to plot data in real time."""
    columns = []

-    selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] in ["float32", "int32"]]
+    selected_columns = [
+        col
+        for col, ft in dataset.features.items()
+        if ft["dtype"] in ["float32", "int32"]
+    ]
    selected_columns.remove("timestamp")

    ignored_columns = []

@@ -258,7 +288,10 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
            else dataset.features[column_name].shape[0]
        )

-        if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
+        if (
+            "names" in dataset.features[column_name]
+            and dataset.features[column_name]["names"]
+        ):
            column_names = dataset.features[column_name]["names"]
            while not isinstance(column_names, list):
                column_names = list(column_names.values())[0]

@@ -281,8 +314,12 @@ def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index)
    else:
        repo_id = dataset.repo_id

-    url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
-        episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
-    )
+    url = (
+        f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+        + dataset.data_path.format(
+            episode_chunk=int(episode_index) // dataset.chunks_size,
+            episode_index=episode_index,
+        )
+    )
    df = pd.read_parquet(url)
    data = df[selected_columns]  # Select specific columns

@@ -315,7 +352,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
    ]


-def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
+def get_episode_language_instruction(
+    dataset: LeRobotDataset, ep_index: int
+) -> list[str]:
    # check if the dataset has language instructions
    if "language_instruction" not in dataset.features:
        return None

@@ -326,12 +365,15 @@ def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) ->
    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
    # with the tf.tensor appearing in the string
-    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
+    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
+        "', shape=(), dtype=string)"
+    )


def get_dataset_info(repo_id: str) -> IterableNamespace:
    response = requests.get(
-        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json", timeout=5
+        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json",
+        timeout=5,
    )
    response.raise_for_status()  # Raises an HTTPError for bad responses
    dataset_info = response.json()

@@ -361,7 +403,9 @@ def visualize_dataset_html(
        if force_override:
            shutil.rmtree(output_dir)
        else:
-            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
+            logging.info(
+                f"Output directory already exists. Loading from it: '{output_dir}'"
+            )

    output_dir.mkdir(parents=True, exist_ok=True)