diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 6d9c0338..2e94a9c9 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -111,35 +111,46 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def init_logging(log_file: Path | None = None, display_pid: bool = False): - def custom_format(record): +def init_logging( + log_file: Path | None = None, + display_pid: bool = False, + console_level: str = "INFO", + file_level: str = "DEBUG", +): + def custom_format(record: logging.LogRecord) -> str: dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" # NOTE: Display PID is useful for multi-process logging. if display_pid: pid_str = f"[PID: {os.getpid()}]" - message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.getMessage()}" else: - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}" return message - logging.basicConfig(level=logging.INFO) - - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - formatter = logging.Formatter() formatter.format = custom_format + + logger = logging.getLogger() + logger.setLevel(logging.NOTSET) # Set the logger to the lowest level to capture all messages + + # Remove unused default handlers + for handler in logger.handlers[:]: + logger.removeHandler(handler) + + # Write logs to console console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) - logging.getLogger().addHandler(console_handler) + console_handler.setLevel(console_level.upper()) + logger.addHandler(console_handler) + # Additionally write logs to file if log_file is not None: - # Additionally write logs to file file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) - logging.getLogger().addHandler(file_handler) + file_handler.setLevel(file_level.upper()) + logger.addHandler(file_handler) def format_big_number(num, precision=0):