import io
import logging
import math
import random
import zipfile
from datetime import datetime
from pathlib import Path

import numpy as np
import requests
import torch
import tqdm


def download_and_extract_zip(
    url: str, destination_folder: Path, timeout: float = 60.0
) -> bool:
    """Download a zip archive from *url* into memory and extract it.

    The whole archive is buffered in an in-memory ``io.BytesIO`` before
    extraction, with a tqdm progress bar shown while downloading.

    Args:
        url: URL of the zip archive.
        destination_folder: Directory the archive members are extracted into.
        timeout: Value passed to ``requests.get(timeout=...)``. It bounds the
            connect and each individual socket read, not the total download
            time; without it ``requests`` can block forever on a stalled
            server.

    Returns:
        True if the server answered with HTTP 200 and the archive was
        extracted, False for any other status code.

    Raises:
        requests.RequestException: On connection errors or timeouts.
        zipfile.BadZipFile: If the downloaded payload is not a zip archive.
    """
    print(f"downloading from {url}")
    # The context manager closes the streamed response (releasing the
    # connection) on every exit path, including the early ``return False``.
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            return False
        # ``None`` makes tqdm show an open-ended counter when the server does
        # not report a length (a total of 0 would render a broken bar).
        total_size = int(response.headers.get("content-length", 0)) or None
        zip_file = io.BytesIO()
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=64 * 1024):
                if chunk:  # skip empty keep-alive chunks
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))

    zip_file.seek(0)
    # NOTE: zipfile strips absolute paths and ".." components on extract, but
    # the archive contents are still trusted; inspect archives from
    # untrusted sources before extracting them.
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(destination_folder)
    return True


def set_seed(seed: int) -> None:
    """Set seed for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and all-CUDA-device generators with the same value. cuDNN determinism
    settings (``torch.backends.cudnn``) are not touched here.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class _PathLineFormatter(logging.Formatter):
    """Render records as ``LEVEL YYYY-mm-dd HH:MM:SS <path:line> message``.

    The ``path:line`` location is truncated to its last 15 characters and
    right-aligned in a 15-character column. Exception tracebacks and stack
    info, when present, are appended on the following lines.
    """

    def format(self, record: logging.LogRecord) -> str:
        # Use the time the record was created, not the time it is formatted.
        dt = datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S")
        fnameline = f"{record.pathname}:{record.lineno}"
        # getMessage() merges %-style args, e.g. logging.info("x=%s", x);
        # using record.msg directly would print the raw template.
        message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"
        # Mirror logging.Formatter: cache the formatted traceback on the record.
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            message = f"{message}\n{record.exc_text}"
        if record.stack_info:
            message = f"{message}\n{self.formatStack(record.stack_info)}"
        return message


def init_logging(level: int = logging.INFO) -> None:
    """Route root-logger output to one console handler with a compact format.

    Every handler currently attached to the root logger is removed, so
    repeated calls do not duplicate output. Removed handlers are not closed,
    since they may be owned by other code.

    Args:
        level: Level set on the root logger (default ``logging.INFO``).
    """
    root = logging.getLogger()
    # Set the level explicitly: logging.basicConfig() does nothing when the
    # root logger already has handlers, so it cannot be relied on for this.
    root.setLevel(level)
    for handler in root.handlers[:]:
        root.removeHandler(handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(_PathLineFormatter())
    root.addHandler(console_handler)


def format_big_number(num: float) -> str:
    """Format *num* compactly with a K/M/B/T/Q (powers of 1000) suffix.

    The scaled value is rounded to an integer. Magnitudes of 1000Q and above
    keep the ``Q`` suffix with a larger leading number rather than losing
    magnitude. NaN and infinities are returned as ``"nan"``, ``"inf"``,
    ``"-inf"``.

    >>> format_big_number(999)
    '999'
    >>> format_big_number(999_999)
    '1M'
    >>> format_big_number(-1_500)
    '-2K'
    """
    if not math.isfinite(num):
        return f"{num:.0f}"
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0
    # Stop before the last suffix so the final division is never applied
    # past "Q" (which previously turned e.g. 1e21 into a bare 1000.0).
    for suffix in suffixes[:-1]:
        # Compare against divisor - 0.5: values that would *round* to 1000
        # (e.g. 999.6 -> "1000") move to the next suffix instead.
        if abs(num) < divisor - 0.5:
            return f"{num:.0f}{suffix}"
        num /= divisor
    return f"{num:.0f}{suffixes[-1]}"