forked from tangger/lerobot
fix(lerobot/scripts): remove lint warnings/errors
This commit is contained in:
@@ -90,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(series_baudrate_table.values())
|
||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||
baudrate = None
|
||||
|
||||
for baudrate in all_baudrates:
|
||||
motor_bus.set_bus_baudrate(baudrate)
|
||||
|
||||
@@ -80,7 +80,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
"""
|
||||
|
||||
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
|
||||
@@ -259,6 +259,10 @@ def eval_policy(
|
||||
threads = [] # for video saving threads
|
||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||
|
||||
video_paths: list[str] = [] # max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
|
||||
episode_data: dict | None = None # return_episode_data == True
|
||||
|
||||
# Callback for visualization.
|
||||
def render_frame(env: gym.vector.VectorEnv):
|
||||
# noqa: B023
|
||||
@@ -271,19 +275,11 @@ def eval_policy(
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
@@ -320,13 +316,19 @@ def eval_policy(
|
||||
else:
|
||||
all_seeds.append(None)
|
||||
|
||||
# FIXME: episode_data is either None or it doesn't exist
|
||||
if return_episode_data:
|
||||
if episode_data is None:
|
||||
start_data_index = 0
|
||||
elif isinstance(episode_data, dict):
|
||||
start_data_index = episode_data["index"][-1].item() + 1
|
||||
else:
|
||||
start_data_index = 0
|
||||
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=start_data_index,
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
@@ -453,6 +455,7 @@ def _compile_episode_data(
|
||||
return data_dict
|
||||
|
||||
|
||||
# TODO(Steven): [WARN] Redefining built-in 'eval'
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
print(info["aggregated"])
|
||||
|
||||
# Save info
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -53,6 +53,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# TODO(Steven): #711 Broke this
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
@@ -89,7 +90,7 @@ def save_meta_data(
|
||||
|
||||
# save info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
with open(str(info_path), "w") as f:
|
||||
with open(str(info_path), "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
|
||||
# save stats
|
||||
@@ -120,11 +121,11 @@ def push_dataset_card_to_hub(
|
||||
repo_id: str,
|
||||
revision: str | None,
|
||||
tags: list | None = None,
|
||||
license: str = "apache-2.0",
|
||||
dataset_license: str = "apache-2.0",
|
||||
**card_kwargs,
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
||||
card = create_lerobot_dataset_card(tags=tags, license=dataset_license, **card_kwargs)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
@@ -213,6 +214,7 @@ def push_dataset_to_hub(
|
||||
encoding,
|
||||
)
|
||||
|
||||
# TODO(Steven): This doesn't seem to exist, maybe it was removed/changed recently?
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
|
||||
@@ -155,12 +155,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
logging.info("cfg.env.task=%s", cfg.env.task)
|
||||
logging.info("cfg.steps=%s (%s)", cfg.steps, format_big_number(cfg.steps))
|
||||
logging.info("dataset.num_frames=%s (%s)", dataset.num_frames, format_big_number(dataset.num_frames))
|
||||
logging.info("dataset.num_episodes=%s", dataset.num_episodes)
|
||||
logging.info(
|
||||
"num_learnable_params=%s (%s)", num_learnable_params, format_big_number(num_learnable_params)
|
||||
)
|
||||
logging.info("num_total_params=%s (%s)", num_total_params, format_big_number(num_total_params))
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
@@ -238,7 +240,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
logging.info("Checkpoint policy after step %s", step)
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
@@ -247,7 +249,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
logging.info("Eval policy at step %s", step)
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
|
||||
@@ -150,7 +150,7 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
str(dataset.meta.version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -358,7 +358,7 @@ def visualize_dataset_html(
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
logging.info("Output directory already exists. Loading from it: '%s'", {output_dir})
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user