Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare
2024-04-05 12:00:31 +01:00
3 changed files with 28 additions and 2 deletions

View File

@@ -203,3 +203,12 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
torch.save(stats, stats_path)
return stats
def cycle(iterable):
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)

View File

@@ -95,3 +95,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
def print_cuda_memory_usage():
import gc
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))