75 lines
1.9 KiB
Python
75 lines
1.9 KiB
Python
import io
|
|
import logging
|
|
import random
|
|
import zipfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import requests
|
|
import torch
|
|
import tqdm
|
|
|
|
|
|
def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    """Download a zip archive from ``url`` and extract it into ``destination_folder``.

    The whole archive is buffered in memory (``io.BytesIO``) while a tqdm
    progress bar tracks the received bytes, then it is extracted with
    ``zipfile.ZipFile.extractall``.

    Args:
        url: Address of the zip archive to fetch with an HTTP GET.
        destination_folder: Directory the archive members are extracted into.

    Returns:
        True if the server answered with status 200 and the archive was
        extracted, False for any other status code.

    Exceptions raised by ``requests`` (network errors) or ``zipfile``
    (e.g. the payload is not a valid zip) are not caught here.
    """
    print(f"downloading from {url}")

    # A streamed response keeps its connection checked out until it is closed;
    # the context manager releases it on every path (non-200, error, success).
    # NOTE(review): no timeout is passed, so a stalled server can block forever.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            return False

        # Servers may omit content-length; tqdm then shows a bar with total 0.
        total_size = int(response.headers.get("content-length", 0))
        zip_file = io.BytesIO()

        # The context manager closes the bar even if the download loop raises.
        with tqdm.tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    zip_file.write(chunk)
                    progress_bar.update(len(chunk))

    # Rewind the in-memory buffer before handing it to zipfile.
    zip_file.seek(0)
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        zip_ref.extractall(destination_folder)
    return True
|
|
|
|
|
|
def set_seed(seed):
    """Seed every random number generator this module relies on.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator, torch's default generator and torch's CUDA
    generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
|
|
def init_logging():
    """Route root-logger output to the console in a compact one-line format.

    The root logger's level is set to INFO, every handler currently attached
    to it is removed, and a single ``StreamHandler`` is installed. Each line
    has the form::

        LEVEL YYYY-MM-DD HH:MM:SS <last 15 chars of "path:lineno"> message
    """

    class _ConsoleFormatter(logging.Formatter):
        """Formatter producing the compact console line described above."""

        def format(self, record):
            # The timestamp is taken when the record is formatted.
            dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            fnameline = f"{record.pathname}:{record.lineno}"
            # getMessage() merges record.args into the message, so lazy calls
            # such as logging.info("x=%s", x) render their arguments instead
            # of printing the raw "%s" template.
            return f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.getMessage()}"

    root_logger = logging.getLogger()

    # Set the level directly: logging.basicConfig() silently does nothing when
    # the root logger already has handlers, which would leave INFO disabled.
    root_logger.setLevel(logging.INFO)

    # Iterate over a copy because removeHandler mutates the list.
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(_ConsoleFormatter())
    root_logger.addHandler(console_handler)
|
|
|
|
|
|
def format_big_number(num, precision=0):
    """Format a number with a thousands-based magnitude suffix.

    The value is divided by 1000 until its magnitude drops below 1000 and is
    then rendered with the matching suffix: "" (units), "K", "M", "B", "T"
    or "Q".

    Args:
        num: The number to format (int or float, may be negative).
        precision: Number of digits after the decimal point. Defaults to 0,
            which matches the historical output.

    Returns:
        The formatted string, e.g. ``"1K"`` for 1000 or ``"1M"`` for
        1_234_567. Values too large for the last suffix are still expressed
        with it (1e18 -> ``"1000Q"``) so the result is always a string.
    """
    suffixes = ["", "K", "M", "B", "T", "Q"]
    divisor = 1000.0

    # Every suffix but the last is used only while the value fits under 1000.
    for suffix in suffixes[:-1]:
        if abs(num) < divisor:
            return f"{num:.{precision}f}{suffix}"
        num /= divisor

    # The largest suffix is the catch-all, so huge inputs do not fall through
    # and come back as a bare, over-divided float.
    return f"{num:.{precision}f}{suffixes[-1]}"
|