forked from tangger/lerobot
Provide more information to the user (#358)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
@@ -27,6 +28,12 @@ import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def inside_slurm():
|
||||
"""Check whether the python process was launched through slurm"""
|
||||
# TODO(rcadene): return False for interactive mode `--pty bash`
|
||||
return "SLURM_JOB_ID" in os.environ
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
match cfg_device:
|
||||
@@ -158,7 +165,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
|
||||
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
||||
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ eval:
|
||||
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
|
||||
batch_size: 1
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
use_async_envs: false
|
||||
use_async_envs: true
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
|
||||
5
lerobot/configs/env/aloha.yaml
vendored
5
lerobot/configs/env/aloha.yaml
vendored
@@ -2,6 +2,11 @@
|
||||
|
||||
fps: 50
|
||||
|
||||
eval:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# set it to false to avoid some problems of the aloha env
|
||||
use_async_envs: false
|
||||
|
||||
env:
|
||||
name: aloha
|
||||
task: AlohaInsertion-v0
|
||||
|
||||
5
lerobot/configs/env/xarm.yaml
vendored
5
lerobot/configs/env/xarm.yaml
vendored
@@ -2,6 +2,11 @@
|
||||
|
||||
fps: 15
|
||||
|
||||
eval:
|
||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||
# set it to false to avoid some problems of the aloha env
|
||||
use_async_envs: false
|
||||
|
||||
env:
|
||||
name: xarm
|
||||
task: XarmLift-v0
|
||||
|
||||
@@ -70,7 +70,13 @@ from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.io_utils import write_video
|
||||
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
from lerobot.common.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def rollout(
|
||||
@@ -79,7 +85,6 @@ def rollout(
|
||||
seeds: list[int] | None = None,
|
||||
return_observations: bool = False,
|
||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||
enable_progbar: bool = False,
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout once through a batch of environments.
|
||||
|
||||
@@ -109,7 +114,6 @@ def rollout(
|
||||
are returned optionally because they typically take more memory to cache. Defaults to False.
|
||||
render_callback: Optional rendering callback to be used after the environments are reset, and after
|
||||
every step.
|
||||
enable_progbar: Enable a progress bar over rollout steps.
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
@@ -136,7 +140,7 @@ def rollout(
|
||||
progbar = trange(
|
||||
max_steps,
|
||||
desc=f"Running rollout with at most {max_steps} steps",
|
||||
disable=not enable_progbar,
|
||||
disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs
|
||||
leave=False,
|
||||
)
|
||||
while not np.all(done):
|
||||
@@ -210,8 +214,6 @@ def eval_policy(
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
enable_progbar: bool = False,
|
||||
enable_inner_progbar: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
@@ -224,8 +226,6 @@ def eval_policy(
|
||||
the "episodes" key of the returned dictionary.
|
||||
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
|
||||
seed is incremented by 1. If not provided, the environments are not manually seeded.
|
||||
enable_progbar: Enable progress bar over batches.
|
||||
enable_inner_progbar: Enable progress bar over steps in each batch.
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
@@ -266,7 +266,8 @@ def eval_policy(
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
|
||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
@@ -285,7 +286,6 @@ def eval_policy(
|
||||
seeds=list(seeds) if seeds else None,
|
||||
return_observations=return_episode_data,
|
||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||
enable_progbar=enable_inner_progbar,
|
||||
)
|
||||
|
||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||
@@ -487,8 +487,6 @@ def main(
|
||||
max_episodes_rendered=10,
|
||||
videos_dir=Path(out_dir) / "videos",
|
||||
start_seed=hydra_cfg.seed,
|
||||
enable_progbar=True,
|
||||
enable_inner_progbar=True,
|
||||
)
|
||||
print(info["aggregated"])
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
|
||||
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
|
||||
|
||||
Reference in New Issue
Block a user