[HIL-SERL] Review feedback modifications (#1112)

This commit is contained in:
Adil Zouitine
2025-05-15 15:24:41 +02:00
committed by GitHub
parent c7a3973653
commit 2051dd38fc
17 changed files with 504 additions and 180 deletions

View File

@@ -19,9 +19,10 @@ import os.path as osp
import platform
import subprocess
import time
from copy import copy
from copy import copy, deepcopy
from datetime import datetime, timezone
from pathlib import Path
from statistics import mean
import numpy as np
import torch
@@ -108,11 +109,14 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def init_logging(log_file=None):
def init_logging(log_file: Path | None = None, display_pid: bool = False):
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
fnameline = f"{record.pathname}:{record.lineno}"
message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}"
# NOTE: Display PID is useful for multi-process logging.
pid_str = f"[PID: {os.getpid()}]" if display_pid else ""
message = f"{record.levelname} {pid_str} {dt} {fnameline[-15:]:>15} {record.msg}"
return message
logging.basicConfig(level=logging.INFO)
@@ -238,30 +242,99 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
class TimerManager:
"""
Lightweight utility to measure elapsed time.
Examples
--------
>>> timer = TimerManager("Policy", log=False)
>>> for _ in range(3):
... with timer:
... time.sleep(0.01)
>>> print(timer.last, timer.fps_avg, timer.percentile(90))
"""
def __init__(
self,
elapsed_time_list: list[float] | None = None,
label="Elapsed time",
log=True,
label: str = "Elapsed-time",
log: bool = True,
logger: logging.Logger | None = None,
):
self.label = label
self.elapsed_time_list = elapsed_time_list
self.log = log
self.elapsed = 0.0
self.logger = logger
self._start: float | None = None
self._history: list[float] = []
def __enter__(self):
self.start = time.perf_counter()
return self.start()
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def start(self):
self._start = time.perf_counter()
return self
def __exit__(self, exc_type, exc_value, traceback):
self.elapsed: float = time.perf_counter() - self.start
if self.elapsed_time_list is not None:
self.elapsed_time_list.append(self.elapsed)
def stop(self) -> float:
if self._start is None:
raise RuntimeError("Timer was never started.")
elapsed = time.perf_counter() - self._start
self._history.append(elapsed)
self._start = None
if self.log:
print(f"{self.label}: {self.elapsed:.6f} seconds")
if self.logger is not None:
self.logger.info(f"{self.label}: {elapsed:.6f} s")
else:
logging.info(f"{self.label}: {elapsed:.6f} s")
return elapsed
def reset(self):
self._history.clear()
@property
def elapsed_seconds(self):
return self.elapsed
def last(self) -> float:
return self._history[-1] if self._history else 0.0
@property
def avg(self) -> float:
return mean(self._history) if self._history else 0.0
@property
def total(self) -> float:
return sum(self._history)
@property
def count(self) -> int:
return len(self._history)
@property
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
@property
def fps_avg(self) -> float:
return 0.0 if self.avg == 0 else 1.0 / self.avg
def percentile(self, p: float) -> float:
"""
Return the p-th percentile of recorded times.
"""
if not self._history:
return 0.0
return float(np.percentile(self._history, p))
def fps_percentile(self, p: float) -> float:
"""
FPS corresponding to the p-th percentile time.
"""
val = self.percentile(p)
return 0.0 if val == 0 else 1.0 / val