Add test_image_writer, accept PIL images, improve ImageWriter perf in main process

This commit is contained in:
Simon Alibert
2024-11-02 20:00:07 +01:00
parent 375abd3020
commit 6b2ec1ed77
4 changed files with 426 additions and 16 deletions

View File

@@ -19,8 +19,8 @@ import threading
from pathlib import Path
import numpy as np
import PIL.Image
import torch
from PIL import Image
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
@@ -40,10 +40,27 @@ def safe_stop_image_writer(func):
return wrapper
def write_image(image_array: np.ndarray, fpath: Path):
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
if image_array.dtype != np.uint8:
# Assume the image is in [0, 1] range for floating-point data
image_array = np.clip(image_array, 0, 1)
image_array = (image_array * 255).astype(np.uint8)
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
image = Image.fromarray(image_array)
image.save(fpath)
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:
raise TypeError(f"Unsupported image type: {type(image)}")
img.save(fpath)
except Exception as e:
print(f"Error writing image {fpath}: {e}")
@@ -63,7 +80,6 @@ def worker_process(queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_process, args=(queue,))
t.daemon = True
t.start()
threads.append(t)
for t in threads:
@@ -95,6 +111,10 @@ class ImageWriter:
self.queue = None
self.threads = []
self.processes = []
self._stopped = False
if num_threads <= 0 and num_processes <= 0:
raise ValueError("Number of threads and processes must be greater than zero.")
if self.num_processes == 0:
# Use threading
@@ -109,7 +129,6 @@ class ImageWriter:
self.queue = multiprocessing.JoinableQueue()
for _ in range(self.num_processes):
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
p.daemon = True
p.start()
self.processes.append(p)
@@ -124,27 +143,33 @@ class ImageWriter:
episode_index=episode_index, image_key=image_key, frame_index=0
).parent
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
if isinstance(image_array, torch.Tensor):
image_array = image_array.numpy()
self.queue.put((image_array, fpath))
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath))
def wait_until_done(self):
self.queue.join()
def stop(self):
if self._stopped:
return
if self.num_processes == 0:
# For threading
for _ in self.threads:
self.queue.put(None)
for t in self.threads:
t.join()
else:
# For multiprocessing
num_nones = self.num_processes * self.num_threads
for _ in range(num_nones):
self.queue.put(None)
self.queue.close()
self.queue.join_thread()
for p in self.processes:
p.join()
if p.is_alive():
p.terminate()
self.queue.close()
self.queue.join_thread()
self._stopped = True

View File

@@ -590,7 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True)
self.image_writer.save_image(
image_array=frame[cam_key],
image=frame[cam_key],
fpath=img_path,
)