Provide more information to the user (#358)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Zhuoheng Li authored on 2024-08-23 18:00:35 +08:00 (committed by GitHub)
parent b5ad79a7d3
commit a2592a5563
8 changed files with 76 additions and 15 deletions

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
 import os.path as osp
 import random
 from contextlib import contextmanager
@@ -27,6 +28,12 @@ import torch
 from omegaconf import DictConfig
 
 
+def inside_slurm():
+    """Check whether the python process was launched through slurm"""
+    # TODO(rcadene): return False for interactive mode `--pty bash`
+    return "SLURM_JOB_ID" in os.environ
+
+
 def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     match cfg_device:
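
Not part of the diff, just for context: a minimal sketch of how a caller might use the new helper, assuming it remains importable from lerobot's utils module (the import path and the make_job_name helper below are illustrative, not taken from this commit).

import os

from lerobot.common.utils.utils import inside_slurm  # assumed import path

def make_job_name(base: str) -> str:
    # Tag the run with the scheduler-assigned job id when launched through SLURM.
    if inside_slurm():
        return f"{base}-{os.environ['SLURM_JOB_ID']}"
    return base
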
@@ -158,7 +165,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
         version_base="1.2",
     )
     cfg = hydra.compose(Path(config_path).stem, overrides)
+    if cfg.eval.batch_size > cfg.eval.n_episodes:
+        raise ValueError(
+            "The eval batch size is greater than the number of eval episodes "
+            f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
+            f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
+            "This might significantly slow down evaluation. To fix this, you should update your command "
+            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
+            f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
+        )
     return cfg
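
For illustration only (the config path and overrides below are hypothetical), a sketch of how the new check surfaces to a user when the composed config asks for a larger eval batch size than the number of eval episodes:

from lerobot.common.utils.utils import init_hydra_config  # assumed import path

try:
    cfg = init_hydra_config(
        "lerobot/configs/default.yaml",  # illustrative config path
        overrides=["eval.n_episodes=1", "eval.batch_size=10"],
    )
except ValueError as err:
    # The message suggests either raising eval.n_episodes to 10 or lowering eval.batch_size to 1.
    print(err)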