Provide more information to the user (#358)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
@@ -27,6 +28,12 @@ import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
def inside_slurm():
|
||||
"""Check whether the python process was launched through slurm"""
|
||||
# TODO(rcadene): return False for interactive mode `--pty bash`
|
||||
return "SLURM_JOB_ID" in os.environ
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
"""Given a string, return a torch.device with checks on whether the device is available."""
|
||||
match cfg_device:
|
||||
@@ -158,7 +165,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||
version_base="1.2",
|
||||
)
|
||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||
|
||||
if cfg.eval.batch_size > cfg.eval.n_episodes:
|
||||
raise ValueError(
|
||||
"The eval batch size is greater than the number of eval episodes "
|
||||
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
|
||||
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
|
||||
"This might significantly slow down evaluation. To fix this, you should update your command "
|
||||
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
|
||||
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user