rich annotations & update open-pi fsdp explanations

This commit is contained in:
Leon998
2026-03-18 13:59:52 +08:00
parent 814f3c3526
commit 4934c4794e
11 changed files with 349 additions and 32 deletions

View File

@@ -27,6 +27,10 @@ class BaseLogger(ABC):
self.object_data_logger: Dict[str, List[Any]] = {}
self.color_image_logger: Dict[str, List[Any]] = {}
self.depth_image_logger: Dict[str, List[Any]] = {}
self.seg_image_logger: Dict[str, List[Any]] = {}
self.color_image_step_logger: Dict[str, List[Any]] = {}
self.depth_image_step_logger: Dict[str, List[Any]] = {}
self.seg_image_step_logger: Dict[str, List[Any]] = {}
def update_tpi_initial_info(self, tpi_initial_info):
self.tpi_initial_info = tpi_initial_info
@@ -67,17 +71,50 @@ class BaseLogger(ABC):
self.scalar_data_logger[robot][key] = []
self.scalar_data_logger[robot][key].append(value)
def add_color_image(self, robot, key, value):
if robot not in self.color_image_logger:
self.color_image_logger[robot] = {}
if key not in self.color_image_logger[robot]:
self.color_image_logger[robot][key] = []
self.color_image_logger[robot][key].append(value)
def _add_image_data(self, data_logger, step_logger, robot, key, value, step_idx=None):
if robot not in data_logger:
data_logger[robot] = {}
if key not in data_logger[robot]:
data_logger[robot][key] = []
data_logger[robot][key].append(value)
# def add_depth_image(self, key, value):
# if key not in self.depth_image_logger:
# self.depth_image_logger[key] = []
# self.depth_image_logger[key].append(value)
if robot not in step_logger:
step_logger[robot] = {}
if key not in step_logger[robot]:
step_logger[robot][key] = []
if step_idx is None:
step_idx = len(data_logger[robot][key]) - 1
step_logger[robot][key].append(int(step_idx))
def add_color_image(self, robot, key, value, step_idx=None):
self._add_image_data(
self.color_image_logger,
self.color_image_step_logger,
robot,
key,
value,
step_idx=step_idx,
)
def add_depth_image(self, robot, key, value, step_idx=None):
self._add_image_data(
self.depth_image_logger,
self.depth_image_step_logger,
robot,
key,
value,
step_idx=step_idx,
)
def add_seg_image(self, robot, key, value, step_idx=None):
self._add_image_data(
self.seg_image_logger,
self.seg_image_step_logger,
robot,
key,
value,
step_idx=step_idx,
)
def clear(
self,
@@ -97,6 +134,10 @@ class BaseLogger(ABC):
self.scalar_data_logger = {}
self.color_image_logger = {}
self.depth_image_logger = {}
self.seg_image_logger = {}
self.color_image_step_logger = {}
self.depth_image_step_logger = {}
self.seg_image_step_logger = {}
@abstractmethod
def close(self):

View File

@@ -13,6 +13,16 @@ from tqdm import tqdm
DEFAULT_RGB_SCALE_FACTOR = 256000.0
def float_array_to_uint16_png(float_array):
array = np.nan_to_num(float_array, nan=0.0, posinf=0.0, neginf=0.0)
array = np.round(array * 10000.0)
array = np.clip(array, 0, 65535)
return array.astype(np.uint16)
def seg_array_to_uint16_png(seg_array):
array = np.nan_to_num(seg_array, nan=0.0, posinf=0.0, neginf=0.0)
array = np.clip(array, 0, 65535)
return array.astype(np.uint16)
# pylint: disable=line-too-long,unused-argument
class LmdbLogger(BaseLogger):
@@ -63,6 +73,7 @@ class LmdbLogger(BaseLogger):
meta_info["tpi_initial_info"] = self.tpi_initial_info
meta_info["collect_info"] = self.collect_info
meta_info["version"] = self.version
meta_info["image_valid_step_ids"] = {}
# Lmdb
log_path_lmdb = save_dir / "lmdb"
@@ -139,13 +150,20 @@ class LmdbLogger(BaseLogger):
# Save color images
if save_img:
for key, value in self.color_image_logger[robot_name].items():
for key, value in self.color_image_logger.get(robot_name, {}).items():
root_img_path = save_dir / f"{key}"
root_img_path.mkdir(parents=True, exist_ok=True)
step_ids = self.color_image_step_logger.get(robot_name, {}).get(key, [])
if len(step_ids) != len(value):
step_ids = list(range(len(value)))
else:
step_ids = [int(x) for x in step_ids]
meta_info["image_valid_step_ids"][key] = step_ids
meta_info["keys"][key] = []
for i, image in enumerate(tqdm(value)):
step_id = str(i).zfill(4)
step_id = str(step_ids[i]).zfill(4)
txn.put(
f"{key}/{step_id}".encode("utf-8"),
pickle.dumps(cv2.imencode(".jpg", image.astype(np.uint8))[1]),
@@ -154,7 +172,49 @@ class LmdbLogger(BaseLogger):
imageio.mimsave(os.path.join(root_img_path, "demo.mp4"), value, fps=15)
meta_info["num_steps"] = len(value)
for key, value in self.depth_image_logger.get(robot_name, {}).items():
root_img_path = save_dir / f"{key}"
root_img_path.mkdir(parents=True, exist_ok=True)
step_ids = self.depth_image_step_logger.get(robot_name, {}).get(key, [])
if len(step_ids) != len(value):
step_ids = list(range(len(value)))
else:
step_ids = [int(x) for x in step_ids]
meta_info["image_valid_step_ids"][key] = step_ids
meta_info["keys"][key] = []
for i, image in enumerate(tqdm(value)):
step_id = str(step_ids[i]).zfill(4)
depth_image = float_array_to_uint16_png(np.asarray(image))
txn.put(
f"{key}/{step_id}".encode('utf-8'),
pickle.dumps(cv2.imencode('.png', depth_image)[1])
)
meta_info["keys"][key].append(f"{key}/{step_id}".encode('utf-8'))
for key, value in self.seg_image_logger.get(robot_name, {}).items():
root_img_path = save_dir / f"{key}"
root_img_path.mkdir(parents=True, exist_ok=True)
step_ids = self.seg_image_step_logger.get(robot_name, {}).get(key, [])
if len(step_ids) != len(value):
step_ids = list(range(len(value)))
else:
step_ids = [int(x) for x in step_ids]
meta_info["image_valid_step_ids"][key] = step_ids
meta_info["keys"][key] = []
for i, image in enumerate(tqdm(value)):
step_id = str(step_ids[i]).zfill(4)
seg_image = seg_array_to_uint16_png(np.asarray(image))
txn.put(
f"{key}/{step_id}".encode('utf-8'),
pickle.dumps(cv2.imencode('.png', seg_image)[1])
)
meta_info["keys"][key].append(f"{key}/{step_id}".encode('utf-8'))
meta_info["num_steps"] = self.log_num_steps
txn.commit()
lmdb_env.close()
pickle.dump(meta_info, open(os.path.join(save_dir, "meta_info.pkl"), "wb"))