forked from tangger/lerobot
Add test_image_writer, accept PIL images, improve ImageWriter perf in main process
This commit is contained in:
@@ -19,8 +19,8 @@ import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
@@ -40,10 +40,27 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def write_image(image_array: np.ndarray, fpath: Path):
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
image = Image.fromarray(image_array)
|
||||
image.save(fpath)
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
@@ -63,7 +80,6 @@ def worker_process(queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_process, args=(queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
@@ -95,6 +111,10 @@ class ImageWriter:
|
||||
self.queue = None
|
||||
self.threads = []
|
||||
self.processes = []
|
||||
self._stopped = False
|
||||
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
@@ -109,7 +129,6 @@ class ImageWriter:
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
@@ -124,27 +143,33 @@ class ImageWriter:
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def save_image(self, image_array: torch.Tensor | np.ndarray, fpath: Path):
|
||||
if isinstance(image_array, torch.Tensor):
|
||||
image_array = image_array.numpy()
|
||||
self.queue.put((image_array, fpath))
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
def stop(self):
|
||||
if self._stopped:
|
||||
return
|
||||
|
||||
if self.num_processes == 0:
|
||||
# For threading
|
||||
for _ in self.threads:
|
||||
self.queue.put(None)
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
else:
|
||||
# For multiprocessing
|
||||
num_nones = self.num_processes * self.num_threads
|
||||
for _ in range(num_nones):
|
||||
self.queue.put(None)
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
|
||||
self._stopped = True
|
||||
|
||||
@@ -590,7 +590,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.save_image(
|
||||
image_array=frame[cam_key],
|
||||
image=frame[cam_key],
|
||||
fpath=img_path,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user