记录动作以及回放record_demo,replay_demo
This commit is contained in:
169
scripts/tools/test/test_cosmos_prompt_gen.py
Normal file
169
scripts/tools/test/test_cosmos_prompt_gen.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) 2024-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
|
||||
# All rights reserved.
|
||||
#
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Test cases for Cosmos prompt generation script."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.tools.cosmos.cosmos_prompt_gen import generate_prompt, main
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def temp_templates_file():
|
||||
"""Create temporary templates file."""
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".json", delete=False) # noqa: SIM115
|
||||
|
||||
# Create test templates
|
||||
test_templates = {
|
||||
"lighting": ["with bright lighting", "with dim lighting", "with natural lighting"],
|
||||
"color": ["in warm colors", "in cool colors", "in vibrant colors"],
|
||||
"style": ["in a realistic style", "in an artistic style", "in a minimalist style"],
|
||||
"empty_section": [], # Test empty section
|
||||
"invalid_section": "not a list", # Test invalid section
|
||||
}
|
||||
|
||||
# Write templates to file
|
||||
with open(temp_file.name, "w") as f:
|
||||
json.dump(test_templates, f)
|
||||
|
||||
yield temp_file.name
|
||||
# Cleanup
|
||||
os.remove(temp_file.name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_file():
|
||||
"""Create temporary output file."""
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".txt", delete=False) # noqa: SIM115
|
||||
yield temp_file.name
|
||||
# Cleanup
|
||||
os.remove(temp_file.name)
|
||||
|
||||
|
||||
class TestCosmosPromptGen:
|
||||
"""Test cases for Cosmos prompt generation functionality."""
|
||||
|
||||
def test_generate_prompt_valid_templates(self, temp_templates_file):
|
||||
"""Test generating a prompt with valid templates."""
|
||||
prompt = generate_prompt(temp_templates_file)
|
||||
|
||||
# Check that prompt is a string
|
||||
assert isinstance(prompt, str)
|
||||
|
||||
# Check that prompt contains at least one word
|
||||
assert len(prompt.split()) > 0
|
||||
|
||||
# Check that prompt contains valid sections
|
||||
valid_sections = ["lighting", "color", "style"]
|
||||
found_sections = [section for section in valid_sections if section in prompt.lower()]
|
||||
assert len(found_sections) > 0
|
||||
|
||||
def test_generate_prompt_invalid_file(self):
|
||||
"""Test generating a prompt with invalid file path."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
generate_prompt("nonexistent_file.json")
|
||||
|
||||
def test_generate_prompt_invalid_json(self):
|
||||
"""Test generating a prompt with invalid JSON file."""
|
||||
# Create a temporary file with invalid JSON
|
||||
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
|
||||
temp_file.write(b"invalid json content")
|
||||
temp_file.flush()
|
||||
|
||||
try:
|
||||
with pytest.raises(ValueError):
|
||||
generate_prompt(temp_file.name)
|
||||
finally:
|
||||
os.remove(temp_file.name)
|
||||
|
||||
def test_main_function_single_prompt(self, temp_templates_file, temp_output_file):
|
||||
"""Test main function with single prompt generation."""
|
||||
# Mock command line arguments
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
"cosmos_prompt_gen.py",
|
||||
"--templates_path",
|
||||
temp_templates_file,
|
||||
"--num_prompts",
|
||||
"1",
|
||||
"--output_path",
|
||||
temp_output_file,
|
||||
]
|
||||
|
||||
try:
|
||||
main()
|
||||
|
||||
# Check if output file was created
|
||||
assert os.path.exists(temp_output_file)
|
||||
|
||||
# Check content of output file
|
||||
with open(temp_output_file) as f:
|
||||
content = f.read().strip()
|
||||
assert len(content) > 0
|
||||
assert len(content.split("\n")) == 1
|
||||
finally:
|
||||
# Restore original argv
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_main_function_multiple_prompts(self, temp_templates_file, temp_output_file):
|
||||
"""Test main function with multiple prompt generation."""
|
||||
# Mock command line arguments
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
"cosmos_prompt_gen.py",
|
||||
"--templates_path",
|
||||
temp_templates_file,
|
||||
"--num_prompts",
|
||||
"3",
|
||||
"--output_path",
|
||||
temp_output_file,
|
||||
]
|
||||
|
||||
try:
|
||||
main()
|
||||
|
||||
# Check if output file was created
|
||||
assert os.path.exists(temp_output_file)
|
||||
|
||||
# Check content of output file
|
||||
with open(temp_output_file) as f:
|
||||
content = f.read().strip()
|
||||
assert len(content) > 0
|
||||
assert len(content.split("\n")) == 3
|
||||
|
||||
# Check that each line is a valid prompt
|
||||
for line in content.split("\n"):
|
||||
assert len(line) > 0
|
||||
finally:
|
||||
# Restore original argv
|
||||
sys.argv = original_argv
|
||||
|
||||
def test_main_function_default_output(self, temp_templates_file):
|
||||
"""Test main function with default output path."""
|
||||
# Mock command line arguments
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv
|
||||
sys.argv = ["cosmos_prompt_gen.py", "--templates_path", temp_templates_file, "--num_prompts", "1"]
|
||||
|
||||
try:
|
||||
main()
|
||||
|
||||
# Check if default output file was created
|
||||
assert os.path.exists("prompts.txt")
|
||||
|
||||
# Clean up default output file
|
||||
os.remove("prompts.txt")
|
||||
finally:
|
||||
# Restore original argv
|
||||
sys.argv = original_argv
|
||||
173
scripts/tools/test/test_hdf5_to_mp4.py
Normal file
173
scripts/tools/test/test_hdf5_to_mp4.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) 2024-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
|
||||
# All rights reserved.
|
||||
#
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Test cases for HDF5 to MP4 conversion script."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scripts.tools.hdf5_to_mp4 import get_num_demos, main, write_demo_to_mp4
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def temp_hdf5_file():
|
||||
"""Create temporary HDF5 file with test data."""
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False) # noqa: SIM115
|
||||
with h5py.File(temp_file.name, "w") as h5f:
|
||||
# Create test data structure
|
||||
for demo_id in range(2): # Create 2 demos
|
||||
demo_group = h5f.create_group(f"data/demo_{demo_id}/obs")
|
||||
|
||||
# Create RGB frames (2 frames per demo)
|
||||
rgb_data = np.random.randint(0, 255, (2, 704, 1280, 3), dtype=np.uint8)
|
||||
demo_group.create_dataset("table_cam", data=rgb_data)
|
||||
|
||||
# Create segmentation frames
|
||||
seg_data = np.random.randint(0, 255, (2, 704, 1280, 4), dtype=np.uint8)
|
||||
demo_group.create_dataset("table_cam_segmentation", data=seg_data)
|
||||
|
||||
# Create normal maps
|
||||
normals_data = np.random.rand(2, 704, 1280, 3).astype(np.float32)
|
||||
demo_group.create_dataset("table_cam_normals", data=normals_data)
|
||||
|
||||
# Create depth maps
|
||||
depth_data = np.random.rand(2, 704, 1280, 1).astype(np.float32)
|
||||
demo_group.create_dataset("table_cam_depth", data=depth_data)
|
||||
|
||||
yield temp_file.name
|
||||
# Cleanup
|
||||
os.remove(temp_file.name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_dir():
|
||||
"""Create temporary output directory."""
|
||||
temp_dir = tempfile.mkdtemp() # noqa: SIM115
|
||||
yield temp_dir
|
||||
# Cleanup
|
||||
for file in os.listdir(temp_dir):
|
||||
os.remove(os.path.join(temp_dir, file))
|
||||
os.rmdir(temp_dir)
|
||||
|
||||
|
||||
class TestHDF5ToMP4:
|
||||
"""Test cases for HDF5 to MP4 conversion functionality."""
|
||||
|
||||
def test_get_num_demos(self, temp_hdf5_file):
|
||||
"""Test the get_num_demos function."""
|
||||
num_demos = get_num_demos(temp_hdf5_file)
|
||||
assert num_demos == 2
|
||||
|
||||
def test_write_demo_to_mp4_rgb(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing RGB frames to MP4."""
|
||||
write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam", temp_output_dir, 704, 1280)
|
||||
|
||||
output_file = os.path.join(temp_output_dir, "demo_0_table_cam.mp4")
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
|
||||
def test_write_demo_to_mp4_segmentation(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing segmentation frames to MP4."""
|
||||
write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_segmentation", temp_output_dir, 704, 1280)
|
||||
|
||||
output_file = os.path.join(temp_output_dir, "demo_0_table_cam_segmentation.mp4")
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
|
||||
def test_write_demo_to_mp4_normals(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing normal maps to MP4."""
|
||||
write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_normals", temp_output_dir, 704, 1280)
|
||||
|
||||
output_file = os.path.join(temp_output_dir, "demo_0_table_cam_normals.mp4")
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
|
||||
def test_write_demo_to_mp4_shaded_segmentation(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing shaded_segmentation frames to MP4."""
|
||||
write_demo_to_mp4(
|
||||
temp_hdf5_file,
|
||||
0,
|
||||
"data/demo_0/obs",
|
||||
"table_cam_shaded_segmentation",
|
||||
temp_output_dir,
|
||||
704,
|
||||
1280,
|
||||
)
|
||||
|
||||
output_file = os.path.join(temp_output_dir, "demo_0_table_cam_shaded_segmentation.mp4")
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
|
||||
def test_write_demo_to_mp4_depth(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing depth maps to MP4."""
|
||||
write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "table_cam_depth", temp_output_dir, 704, 1280)
|
||||
|
||||
output_file = os.path.join(temp_output_dir, "demo_0_table_cam_depth.mp4")
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
|
||||
def test_write_demo_to_mp4_invalid_demo(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing with invalid demo ID."""
|
||||
with pytest.raises(KeyError):
|
||||
write_demo_to_mp4(
|
||||
temp_hdf5_file,
|
||||
999, # Invalid demo ID
|
||||
"data/demo_999/obs",
|
||||
"table_cam",
|
||||
temp_output_dir,
|
||||
704,
|
||||
1280,
|
||||
)
|
||||
|
||||
def test_write_demo_to_mp4_invalid_key(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test writing with invalid input key."""
|
||||
with pytest.raises(KeyError):
|
||||
write_demo_to_mp4(temp_hdf5_file, 0, "data/demo_0/obs", "invalid_key", temp_output_dir, 704, 1280)
|
||||
|
||||
def test_main_function(self, temp_hdf5_file, temp_output_dir):
|
||||
"""Test the main function."""
|
||||
# Mock command line arguments
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
"hdf5_to_mp4.py",
|
||||
"--input_file",
|
||||
temp_hdf5_file,
|
||||
"--output_dir",
|
||||
temp_output_dir,
|
||||
"--input_keys",
|
||||
"table_cam",
|
||||
"table_cam_segmentation",
|
||||
"--video_height",
|
||||
"704",
|
||||
"--video_width",
|
||||
"1280",
|
||||
"--framerate",
|
||||
"30",
|
||||
]
|
||||
|
||||
try:
|
||||
main()
|
||||
|
||||
# Check if output files were created
|
||||
expected_files = [
|
||||
"demo_0_table_cam.mp4",
|
||||
"demo_0_table_cam_segmentation.mp4",
|
||||
"demo_1_table_cam.mp4",
|
||||
"demo_1_table_cam_segmentation.mp4",
|
||||
]
|
||||
|
||||
for file in expected_files:
|
||||
output_file = os.path.join(temp_output_dir, file)
|
||||
assert os.path.exists(output_file)
|
||||
assert os.path.getsize(output_file) > 0
|
||||
finally:
|
||||
# Restore original argv
|
||||
sys.argv = original_argv
|
||||
181
scripts/tools/test/test_mp4_to_hdf5.py
Normal file
181
scripts/tools/test/test_mp4_to_hdf5.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) 2024-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
|
||||
# All rights reserved.
|
||||
#
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Test cases for MP4 to HDF5 conversion script."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from scripts.tools.mp4_to_hdf5 import get_frames_from_mp4, main, process_video_and_demo
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def temp_hdf5_file():
|
||||
"""Create temporary HDF5 file with test data."""
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False) # noqa: SIM115
|
||||
with h5py.File(temp_file.name, "w") as h5f:
|
||||
# Create test data structure for 2 demos
|
||||
for demo_id in range(2):
|
||||
demo_group = h5f.create_group(f"data/demo_{demo_id}")
|
||||
obs_group = demo_group.create_group("obs")
|
||||
|
||||
# Create actions data
|
||||
actions_data = np.random.rand(10, 7).astype(np.float32)
|
||||
demo_group.create_dataset("actions", data=actions_data)
|
||||
|
||||
# Create robot state data
|
||||
eef_pos_data = np.random.rand(10, 3).astype(np.float32)
|
||||
eef_quat_data = np.random.rand(10, 4).astype(np.float32)
|
||||
gripper_pos_data = np.random.rand(10, 1).astype(np.float32)
|
||||
obs_group.create_dataset("eef_pos", data=eef_pos_data)
|
||||
obs_group.create_dataset("eef_quat", data=eef_quat_data)
|
||||
obs_group.create_dataset("gripper_pos", data=gripper_pos_data)
|
||||
|
||||
# Create camera data
|
||||
table_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
|
||||
wrist_cam_data = np.random.randint(0, 255, (10, 704, 1280, 3), dtype=np.uint8)
|
||||
obs_group.create_dataset("table_cam", data=table_cam_data)
|
||||
obs_group.create_dataset("wrist_cam", data=wrist_cam_data)
|
||||
|
||||
# Set attributes
|
||||
demo_group.attrs["num_samples"] = 10
|
||||
|
||||
yield temp_file.name
|
||||
# Cleanup
|
||||
os.remove(temp_file.name)
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def temp_videos_dir():
|
||||
"""Create temporary MP4 files."""
|
||||
temp_dir = tempfile.mkdtemp() # noqa: SIM115
|
||||
video_paths = []
|
||||
|
||||
for demo_id in range(2):
|
||||
video_path = os.path.join(temp_dir, f"demo_{demo_id}_table_cam.mp4")
|
||||
video_paths.append(video_path)
|
||||
|
||||
# Create a test video
|
||||
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
video = cv2.VideoWriter(video_path, fourcc, 30, (1280, 704))
|
||||
|
||||
# Write some random frames
|
||||
for _ in range(10):
|
||||
frame = np.random.randint(0, 255, (704, 1280, 3), dtype=np.uint8)
|
||||
video.write(frame)
|
||||
video.release()
|
||||
|
||||
yield temp_dir, video_paths
|
||||
|
||||
# Cleanup
|
||||
for video_path in video_paths:
|
||||
os.remove(video_path)
|
||||
os.rmdir(temp_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_file():
|
||||
"""Create temporary output file."""
|
||||
temp_file = tempfile.NamedTemporaryFile(suffix=".h5", delete=False) # noqa: SIM115
|
||||
yield temp_file.name
|
||||
# Cleanup
|
||||
os.remove(temp_file.name)
|
||||
|
||||
|
||||
class TestMP4ToHDF5:
|
||||
"""Test cases for MP4 to HDF5 conversion functionality."""
|
||||
|
||||
def test_get_frames_from_mp4(self, temp_videos_dir):
|
||||
"""Test extracting frames from MP4 video."""
|
||||
_, video_paths = temp_videos_dir
|
||||
frames = get_frames_from_mp4(video_paths[0])
|
||||
|
||||
# Check frame properties
|
||||
assert frames.shape[0] == 10 # Number of frames
|
||||
assert frames.shape[1:] == (704, 1280, 3) # Frame dimensions
|
||||
assert frames.dtype == np.uint8 # Data type
|
||||
|
||||
def test_get_frames_from_mp4_resize(self, temp_videos_dir):
|
||||
"""Test extracting frames with resizing."""
|
||||
_, video_paths = temp_videos_dir
|
||||
target_height, target_width = 352, 640
|
||||
frames = get_frames_from_mp4(video_paths[0], target_height, target_width)
|
||||
|
||||
# Check resized frame properties
|
||||
assert frames.shape[0] == 10 # Number of frames
|
||||
assert frames.shape[1:] == (target_height, target_width, 3) # Resized dimensions
|
||||
assert frames.dtype == np.uint8 # Data type
|
||||
|
||||
def test_process_video_and_demo(self, temp_hdf5_file, temp_videos_dir, temp_output_file):
|
||||
"""Test processing a single video and creating a new demo."""
|
||||
_, video_paths = temp_videos_dir
|
||||
with h5py.File(temp_hdf5_file, "r") as f_in, h5py.File(temp_output_file, "w") as f_out:
|
||||
process_video_and_demo(f_in, f_out, video_paths[0], 0, 2)
|
||||
|
||||
# Check if new demo was created with correct data
|
||||
assert "data/demo_2" in f_out
|
||||
assert "data/demo_2/actions" in f_out
|
||||
assert "data/demo_2/obs/eef_pos" in f_out
|
||||
assert "data/demo_2/obs/eef_quat" in f_out
|
||||
assert "data/demo_2/obs/gripper_pos" in f_out
|
||||
assert "data/demo_2/obs/table_cam" in f_out
|
||||
assert "data/demo_2/obs/wrist_cam" in f_out
|
||||
|
||||
# Check data shapes
|
||||
assert f_out["data/demo_2/actions"].shape == (10, 7)
|
||||
assert f_out["data/demo_2/obs/eef_pos"].shape == (10, 3)
|
||||
assert f_out["data/demo_2/obs/eef_quat"].shape == (10, 4)
|
||||
assert f_out["data/demo_2/obs/gripper_pos"].shape == (10, 1)
|
||||
assert f_out["data/demo_2/obs/table_cam"].shape == (10, 704, 1280, 3)
|
||||
assert f_out["data/demo_2/obs/wrist_cam"].shape == (10, 704, 1280, 3)
|
||||
|
||||
# Check attributes
|
||||
assert f_out["data/demo_2"].attrs["num_samples"] == 10
|
||||
|
||||
def test_main_function(self, temp_hdf5_file, temp_videos_dir, temp_output_file):
|
||||
"""Test the main function."""
|
||||
# Mock command line arguments
|
||||
import sys
|
||||
|
||||
original_argv = sys.argv
|
||||
sys.argv = [
|
||||
"mp4_to_hdf5.py",
|
||||
"--input_file",
|
||||
temp_hdf5_file,
|
||||
"--videos_dir",
|
||||
temp_videos_dir[0],
|
||||
"--output_file",
|
||||
temp_output_file,
|
||||
]
|
||||
|
||||
try:
|
||||
main()
|
||||
|
||||
# Check if output file was created with correct data
|
||||
with h5py.File(temp_output_file, "r") as f:
|
||||
# Check if original demos were copied
|
||||
assert "data/demo_0" in f
|
||||
assert "data/demo_1" in f
|
||||
|
||||
# Check if new demos were created
|
||||
assert "data/demo_2" in f
|
||||
assert "data/demo_3" in f
|
||||
|
||||
# Check data in new demos
|
||||
for demo_id in [2, 3]:
|
||||
assert f"data/demo_{demo_id}/actions" in f
|
||||
assert f"data/demo_{demo_id}/obs/eef_pos" in f
|
||||
assert f"data/demo_{demo_id}/obs/eef_quat" in f
|
||||
assert f"data/demo_{demo_id}/obs/gripper_pos" in f
|
||||
assert f"data/demo_{demo_id}/obs/table_cam" in f
|
||||
assert f"data/demo_{demo_id}/obs/wrist_cam" in f
|
||||
finally:
|
||||
# Restore original argv
|
||||
sys.argv = original_argv
|
||||
Reference in New Issue
Block a user