Provide more information to the user (#358)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
@@ -70,7 +70,13 @@ from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import Policy
 from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
-from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
+from lerobot.common.utils.utils import (
+    get_safe_torch_device,
+    init_hydra_config,
+    init_logging,
+    inside_slurm,
+    set_global_seed,
+)
 
 
 def rollout(
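The newly imported `inside_slurm` helper is what the hunks below use to decide whether to draw progress bars. Its definition is not part of this diff; a minimal sketch of such a check, assuming slurm's standard `SLURM_JOB_ID` environment variable (the actual helper in `lerobot.common.utils.utils` may differ), could look like:

```python
import os


def inside_slurm() -> bool:
    # Slurm exports SLURM_JOB_ID for every job it launches, so its
    # presence is a reasonable proxy for "running under slurm".
    return "SLURM_JOB_ID" in os.environ
```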
@@ -79,7 +85,6 @@ def rollout(
     seeds: list[int] | None = None,
     return_observations: bool = False,
     render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
-    enable_progbar: bool = False,
 ) -> dict:
     """Run a batched policy rollout once through a batch of environments.
 
@@ -109,7 +114,6 @@ def rollout(
             are returned optionally because they typically take more memory to cache. Defaults to False.
         render_callback: Optional rendering callback to be used after the environments are reset, and after
             every step.
-        enable_progbar: Enable a progress bar over rollout steps.
     Returns:
         The dictionary described above.
     """
@@ -136,7 +140,7 @@ def rollout(
     progbar = trange(
         max_steps,
         desc=f"Running rollout with at most {max_steps} steps",
-        disable=not enable_progbar,
+        disable=inside_slurm(),  # we dont want progress bar when we use slurm, since it clutters the logs
         leave=False,
     )
     while not np.all(done):
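For context on the change above: `tqdm.trange` treats a truthy `disable` as "do not render the bar at all", so under slurm the loop runs silently and nothing is written to the log files. A standalone illustration of the pattern (hypothetical values, not code from this diff):

```python
from tqdm import trange

# With disable=True the bar is never drawn and trange behaves like a
# plain range, which keeps batch-scheduler logs (e.g. slurm output
# files) free of carriage-return clutter.
for step in trange(100, desc="rollout", disable=True, leave=False):
    pass
```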
@@ -210,8 +214,6 @@ def eval_policy(
     videos_dir: Path | None = None,
     return_episode_data: bool = False,
     start_seed: int | None = None,
-    enable_progbar: bool = False,
-    enable_inner_progbar: bool = False,
 ) -> dict:
     """
     Args:
@@ -224,8 +226,6 @@ def eval_policy(
             the "episodes" key of the returned dictionary.
         start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
             seed is incremented by 1. If not provided, the environments are not manually seeded.
-        enable_progbar: Enable progress bar over batches.
-        enable_inner_progbar: Enable progress bar over steps in each batch.
     Returns:
         Dictionary with metrics and data regarding the rollouts.
     """
@@ -266,7 +266,8 @@ def eval_policy(
     if return_episode_data:
         episode_data: dict | None = None
 
-    progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
+    # we dont want progress bar when we use slurm, since it clutters the logs
+    progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
     for batch_ix in progbar:
         # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
         # step.
@@ -285,7 +286,6 @@ def eval_policy(
             seeds=list(seeds) if seeds else None,
             return_observations=return_episode_data,
             render_callback=render_frame if max_episodes_rendered > 0 else None,
-            enable_progbar=enable_inner_progbar,
         )
 
         # Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -487,8 +487,6 @@ def main(
         max_episodes_rendered=10,
         videos_dir=Path(out_dir) / "videos",
         start_seed=hydra_cfg.seed,
-        enable_progbar=True,
-        enable_inner_progbar=True,
     )
     print(info["aggregated"])
 
@@ -241,6 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         raise NotImplementedError()
 
     init_logging()
+    logging.info(pformat(OmegaConf.to_container(cfg)))
 
     if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
         raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
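This added line is where the commit delivers the "more information" of its title: the fully resolved config is pretty-printed into the logs at startup. A self-contained sketch of the same pattern, using a made-up config in place of the real Hydra one:

```python
import logging
from pprint import pformat

from omegaconf import OmegaConf

logging.basicConfig(level=logging.INFO)

# Stand-in for the Hydra-built DictConfig that train() receives.
cfg = OmegaConf.create({"seed": 1000, "training": {"online_steps": 0}})

# OmegaConf.to_container converts the DictConfig into plain dicts/lists
# so pformat can render it as readable, indented text in the log.
logging.info(pformat(OmegaConf.to_container(cfg)))
```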