Files
issacdataengine/tests/integration/data_comparators/simbox_comparator.py
2026-03-16 11:44:10 +00:00

399 lines
15 KiB
Python

"""
Simbox Output Comparator
This module provides functionality to compare two Simbox task output directories.
It compares both meta_info.pkl and LMDB database contents, handling different data types:
- JSON data (dict/list)
- Scalar data (numerical arrays/lists)
- Image data (encoded images)
- Proprioception data (joint states, gripper states)
- Object data (object poses and properties)
- Action data (robot actions)
"""
import argparse
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import cv2
import lmdb
import numpy as np
class SimboxComparator:
"""Comparator for Simbox task output directories."""
def __init__(self, dir1: str, dir2: str, tolerance: float = 1e-6, image_psnr_threshold: float = 30.0):
"""
Initialize the comparator.
Args:
dir1: Path to the first output directory
dir2: Path to the second output directory
tolerance: Numerical tolerance for floating point comparisons
image_psnr_threshold: PSNR threshold (dB) for considering images as acceptable match
"""
self.dir1 = Path(dir1)
self.dir2 = Path(dir2)
self.tolerance = tolerance
self.image_psnr_threshold = image_psnr_threshold
self.mismatches = []
self.warnings = []
self.image_psnr_values: List[float] = []
def load_directory(self, directory: Path) -> Tuple[Optional[Dict], Optional[Any], Optional[Any]]:
"""Load meta_info.pkl and LMDB database from directory."""
meta_path = directory / "meta_info.pkl"
lmdb_path = directory / "lmdb"
if not directory.is_dir() or not meta_path.exists() or not lmdb_path.is_dir():
print(f"Error: '{directory}' is not a valid output directory.")
print("It must contain 'meta_info.pkl' and an 'lmdb' subdirectory.")
return None, None, None
try:
with open(meta_path, "rb") as f:
meta_info = pickle.load(f)
except Exception as e:
print(f"Error loading meta_info.pkl from {directory}: {e}")
return None, None, None
try:
env = lmdb.open(str(lmdb_path), readonly=True, lock=False, readahead=False, meminit=False)
txn = env.begin(write=False)
except Exception as e:
print(f"Error opening LMDB database at {lmdb_path}: {e}")
return None, None, None
return meta_info, txn, env
def compare_metadata(self, meta1: Dict, meta2: Dict) -> bool:
"""Compare high-level metadata."""
identical = True
if meta1.get("num_steps") != meta2.get("num_steps"):
self.mismatches.append(f"num_steps differ: {meta1.get('num_steps')} vs {meta2.get('num_steps')}")
identical = False
return identical
def get_key_categories(self, meta: Dict) -> Dict[str, set]:
"""Extract key categories from metadata."""
key_to_category = {}
for category, keys in meta.get("keys", {}).items():
for key in keys:
key_bytes = key if isinstance(key, bytes) else key.encode()
key_to_category[key_bytes] = category
return key_to_category
def compare_json_data(self, key: bytes, data1: Any, data2: Any) -> bool:
"""Compare JSON/dict/list data."""
if type(data1) != type(data2):
self.mismatches.append(f"[{key.decode()}] Type mismatch: {type(data1).__name__} vs {type(data2).__name__}")
return False
if isinstance(data1, dict):
if set(data1.keys()) != set(data2.keys()):
self.mismatches.append(f"[{key.decode()}] Dict keys differ")
return False
for k in data1.keys():
if not self.compare_json_data(key, data1[k], data2[k]):
return False
elif isinstance(data1, list):
if len(data1) != len(data2):
self.mismatches.append(f"[{key.decode()}] List length differ: {len(data1)} vs {len(data2)}")
return False
# For lists, compare sample elements to avoid excessive output
if len(data1) > 10:
sample_indices = [0, len(data1) // 2, -1]
for idx in sample_indices:
if not self.compare_json_data(key, data1[idx], data2[idx]):
return False
else:
for i, (v1, v2) in enumerate(zip(data1, data2)):
if not self.compare_json_data(key, v1, v2):
return False
else:
if data1 != data2:
self.mismatches.append(f"[{key.decode()}] Value mismatch: {data1} vs {data2}")
return False
return True
def compare_numerical_data(self, key: bytes, data1: Any, data2: Any) -> bool:
"""Compare numerical data (arrays, lists of numbers)."""
# Convert to numpy arrays for comparison
try:
if isinstance(data1, list):
arr1 = np.array(data1)
arr2 = np.array(data2)
else:
arr1 = data1
arr2 = data2
if arr1.shape != arr2.shape:
self.mismatches.append(f"[{key.decode()}] Shape mismatch: {arr1.shape} vs {arr2.shape}")
return False
if not np.allclose(arr1, arr2, rtol=self.tolerance, atol=self.tolerance):
diff = np.abs(arr1 - arr2)
max_diff = np.max(diff)
mean_diff = np.mean(diff)
self.mismatches.append(
f"[{key.decode()}] Numerical difference: max={max_diff:.6e}, mean={mean_diff:.6e}"
)
return False
except Exception as e:
self.warnings.append(f"[{key.decode()}] Error comparing numerical data: {e}")
return False
return True
def compare_image_data(self, key: bytes, data1: np.ndarray, data2: np.ndarray) -> bool:
"""Compare image data (encoded as uint8 arrays)."""
try:
# Decode images
img1 = cv2.imdecode(data1, cv2.IMREAD_UNCHANGED)
img2 = cv2.imdecode(data2, cv2.IMREAD_UNCHANGED)
if img1 is None or img2 is None:
self.warnings.append(f"[{key.decode()}] Could not decode image, using binary comparison")
return np.array_equal(data1, data2)
# Compare shapes
if img1.shape != img2.shape:
self.mismatches.append(f"[{key.decode()}] Image shape mismatch: {img1.shape} vs {img2.shape}")
return False
# Calculate PSNR for tracking average quality
if np.array_equal(img1, img2):
psnr = 100.0
else:
diff_float = img1.astype(np.float32) - img2.astype(np.float32)
mse = np.mean(diff_float ** 2)
if mse == 0:
psnr = 100.0
else:
max_pixel = 255.0
psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
self.image_psnr_values.append(psnr)
try:
print(f"[{key.decode()}] PSNR: {psnr:.2f} dB")
except Exception:
print(f"[<binary key>] PSNR: {psnr:.2f} dB")
return True
except Exception as e:
self.warnings.append(f"[{key.decode()}] Error comparing image: {e}")
return False
def _save_comparison_image(self, key: bytes, img1: np.ndarray, img2: np.ndarray, diff: np.ndarray):
"""Save comparison visualization for differing images."""
try:
output_dir = Path("image_comparisons")
output_dir.mkdir(exist_ok=True)
# Normalize difference for visualization
if len(diff.shape) == 3:
diff_vis = np.clip(diff * 5, 0, 255).astype(np.uint8)
else:
diff_vis = np.clip(diff * 5, 0, 255).astype(np.uint8)
# Ensure RGB format for concatenation
def ensure_rgb(img):
if len(img.shape) == 2:
return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
elif len(img.shape) == 3 and img.shape[2] == 4:
return cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
return img
img1_rgb = ensure_rgb(img1)
img2_rgb = ensure_rgb(img2)
diff_rgb = ensure_rgb(diff_vis)
# Concatenate horizontally
combined = np.hstack([img1_rgb, img2_rgb, diff_rgb])
# Save
safe_key = key.decode().replace("/", "_").replace("\\", "_").replace(":", "_")
output_path = output_dir / f"diff_{safe_key}.png"
cv2.imwrite(str(output_path), cv2.cvtColor(combined, cv2.COLOR_RGB2BGR))
except Exception as e:
self.warnings.append(f"Failed to save comparison image for {key.decode()}: {e}")
def compare_value(self, key: bytes, category: str, val1: bytes, val2: bytes) -> bool:
"""Compare a single key-value pair based on its category."""
# Output the category for the current key being compared
try:
print(f"[{key.decode()}] category: {category}")
except Exception:
print(f"[<binary key>] category: {category}")
if val1 is None and val2 is None:
return True
if val1 is None or val2 is None:
self.mismatches.append(f"[{key.decode()}] Key exists in one dataset but not the other")
return False
try:
data1 = pickle.loads(val1)
data2 = pickle.loads(val2)
except Exception as e:
self.warnings.append(f"[{key.decode()}] Error unpickling data: {e}")
return val1 == val2
# Route to appropriate comparison based on category
if category == "json_data":
return self.compare_json_data(key, data1, data2)
elif category in ["scalar_data", "proprio_data", "object_data", "action_data"]:
return self.compare_numerical_data(key, data1, data2)
elif category.startswith("images."):
# Image data is stored as numpy uint8 array
if isinstance(data1, np.ndarray) and isinstance(data2, np.ndarray):
return self.compare_image_data(key, data1, data2)
else:
self.warnings.append(f"[{key.decode()}] Expected numpy array for image data")
return False
else:
# Unknown category, try generic comparison
self.warnings.append(f"[{key.decode()}] Unknown category '{category}', using binary comparison")
return val1 == val2
def compare(self) -> bool:
"""Execute full comparison."""
print(f"Comparing directories:")
print(f" Dir1: {self.dir1}")
print(f" Dir2: {self.dir2}\n")
# Load data
meta1, txn1, env1 = self.load_directory(self.dir1)
meta2, txn2, env2 = self.load_directory(self.dir2)
if meta1 is None or meta2 is None:
print("Aborting due to loading errors.")
return False
print("Successfully loaded data from both directories.\n")
# Compare metadata
print("Comparing metadata...")
self.compare_metadata(meta1, meta2)
# Get key categories
key_cat1 = self.get_key_categories(meta1)
key_cat2 = self.get_key_categories(meta2)
keys1 = set(key_cat1.keys())
keys2 = set(key_cat2.keys())
# Check key sets
if keys1 != keys2:
missing_in_2 = sorted([k.decode() for k in keys1 - keys2])
missing_in_1 = sorted([k.decode() for k in keys2 - keys1])
if missing_in_2:
self.mismatches.append(f"Keys missing in dir2: {missing_in_2[:10]}")
if missing_in_1:
self.mismatches.append(f"Keys missing in dir1: {missing_in_1[:10]}")
# Compare common keys
common_keys = sorted(list(keys1.intersection(keys2)))
print(f"Comparing {len(common_keys)} common keys...\n")
for i, key in enumerate(common_keys):
if i % 100 == 0 and i > 0:
print(f"Progress: {i}/{len(common_keys)} keys compared...")
category = key_cat1.get(key, "unknown")
val1 = txn1.get(key)
val2 = txn2.get(key)
self.compare_value(key, category, val1, val2)
if self.image_psnr_values:
avg_psnr = sum(self.image_psnr_values) / len(self.image_psnr_values)
print(
f"\nImage PSNR average over {len(self.image_psnr_values)} images: "
f"{avg_psnr:.2f} dB (threshold {self.image_psnr_threshold:.2f} dB)"
)
if avg_psnr < self.image_psnr_threshold:
self.mismatches.append(
f"Average image PSNR {avg_psnr:.2f} dB below threshold "
f"{self.image_psnr_threshold:.2f} dB"
)
else:
print("\nNo image entries found for PSNR calculation.")
# Cleanup
env1.close()
env2.close()
# Print results
print("\n" + "=" * 80)
print("COMPARISON RESULTS")
print("=" * 80)
if self.warnings:
print(f"\nWarnings ({len(self.warnings)}):")
for warning in self.warnings[:20]:
print(f" - {warning}")
if len(self.warnings) > 20:
print(f" ... and {len(self.warnings) - 20} more warnings")
if self.mismatches:
print(f"\nMismatches found ({len(self.mismatches)}):")
for mismatch in self.mismatches[:30]:
print(f" - {mismatch}")
if len(self.mismatches) > 30:
print(f" ... and {len(self.mismatches) - 30} more mismatches")
print("\n❌ RESULT: Directories are DIFFERENT")
return False
else:
print("\n✅ RESULT: Directories are IDENTICAL")
return True
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Compare two Simbox task output directories.",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
%(prog)s --dir1 output/run1 --dir2 output/run2
%(prog)s --dir1 output/run1 --dir2 output/run2 --tolerance 1e-5
%(prog)s --dir1 output/run1 --dir2 output/run2 --image-psnr 40.0
""",
)
parser.add_argument("--dir1", type=str, required=True, help="Path to the first output directory")
parser.add_argument("--dir2", type=str, required=True, help="Path to the second output directory")
parser.add_argument(
"--tolerance",
type=float,
default=1e-6,
help="Numerical tolerance for floating point comparisons (default: 1e-6)",
)
parser.add_argument(
"--image-psnr",
type=float,
default=37.0,
help="PSNR threshold (dB) for considering images as matching (default: 37.0)",
)
args = parser.parse_args()
comparator = SimboxComparator(args.dir1, args.dir2, tolerance=args.tolerance, image_psnr_threshold=args.image_psnr)
success = comparator.compare()
exit(0 if success else 1)
if __name__ == "__main__":
main()