[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -129,7 +129,12 @@ def patch_builtins_input(monkeypatch):
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")
|
||||
parser.addoption(
|
||||
"--seed",
|
||||
action="store",
|
||||
default="42",
|
||||
help="Set random seed for reproducibility",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
||||
30
tests/fixtures/constants.py
vendored
30
tests/fixtures/constants.py
vendored
@@ -7,17 +7,39 @@ DUMMY_MOTOR_FEATURES = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
"names": [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
"names": [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
],
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"laptop": {
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
"phone": {
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
|
||||
94
tests/fixtures/dataset_factories.py
vendored
94
tests/fixtures/dataset_factories.py
vendored
@@ -8,7 +8,11 @@ import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
CODEBASE_VERSION,
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
@@ -35,7 +39,9 @@ def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
def _create_img_tensor(
|
||||
height=100, width=100, channels=3, dtype=torch.float32
|
||||
) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
@@ -43,10 +49,14 @@ def img_tensor_factory():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
def _create_img_array(
|
||||
height=100, width=100, channels=3, dtype=np.uint8
|
||||
) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
img_array = np.random.randint(
|
||||
0, 256, size=(height, width, channels), dtype=dtype
|
||||
)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
@@ -75,10 +85,13 @@ def features_factory():
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO}
|
||||
for key, ft in camera_features.items()
|
||||
}
|
||||
else:
|
||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
||||
camera_ft = {
|
||||
key: {"dtype": "image", **ft} for key, ft in camera_features.items()
|
||||
}
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
@@ -177,7 +190,9 @@ def episodes_factory(tasks_factory):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
raise ValueError("num_episodes and total_length must be positive integers.")
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
raise ValueError(
|
||||
"total_length must be greater than or equal to num_episodes."
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
@@ -185,10 +200,14 @@ def episodes_factory(tasks_factory):
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
raise ValueError(
|
||||
"The number of tasks should be less than the number of episodes."
|
||||
)
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
lengths = np.random.multinomial(
|
||||
total_frames, [1 / total_episodes] * total_episodes
|
||||
).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks]
|
||||
num_tasks_available = len(tasks_list)
|
||||
@@ -196,9 +215,13 @@ def episodes_factory(tasks_factory):
|
||||
episodes_list = []
|
||||
remaining_tasks = tasks_list.copy()
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
num_tasks_in_episode = (
|
||||
random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
)
|
||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||
episode_tasks = random.sample(
|
||||
tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample))
|
||||
)
|
||||
if remaining_tasks:
|
||||
for task in episode_tasks:
|
||||
remaining_tasks.remove(task)
|
||||
@@ -217,7 +240,9 @@ def episodes_factory(tasks_factory):
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def hf_dataset_factory(
|
||||
features_factory, tasks_factory, episodes_factory, img_array_factory
|
||||
):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
@@ -236,13 +261,22 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
timestamp_col = np.concatenate(
|
||||
(timestamp_col, np.arange(ep_dict["length"]) / fps)
|
||||
)
|
||||
frame_index_col = np.concatenate(
|
||||
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
|
||||
)
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
(
|
||||
episode_index_col,
|
||||
np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int),
|
||||
)
|
||||
)
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
task_index = np.concatenate(
|
||||
(task_index, np.full(ep_dict["length"], ep_task_index, dtype=int))
|
||||
)
|
||||
|
||||
index_col = np.arange(len(episode_index_col))
|
||||
|
||||
@@ -254,7 +288,9 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
robot_cols[key] = np.random.random(
|
||||
(len(index_col), ft["shape"][0])
|
||||
).astype(ft["dtype"])
|
||||
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
dataset = datasets.Dataset.from_dict(
|
||||
@@ -299,7 +335,9 @@ def lerobot_dataset_metadata_factory(
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
@@ -316,10 +354,14 @@ def lerobot_dataset_metadata_factory(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_get_hub_safe_version_patch.side_effect = (
|
||||
lambda repo_id, version: version
|
||||
)
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
return LeRobotDatasetMetadata(
|
||||
repo_id=repo_id, root=root, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
@@ -350,7 +392,9 @@ def lerobot_dataset_factory(
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
total_tasks=total_tasks,
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
@@ -364,7 +408,9 @@ def lerobot_dataset_factory(
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
hf_dataset = hf_dataset_factory(
|
||||
tasks=tasks, episodes=episode_dicts, fps=info["fps"]
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
@@ -383,7 +429,9 @@ def lerobot_dataset_factory(
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
|
||||
) as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
|
||||
12
tests/fixtures/files.py
vendored
12
tests/fixtures/files.py
vendored
@@ -7,7 +7,12 @@ import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -69,7 +74,10 @@ def episode_path(episodes_factory):
|
||||
@pytest.fixture(scope="session")
|
||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
dir: Path,
|
||||
ep_idx: int = 0,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
info: dict | None = None,
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
|
||||
31
tests/fixtures/hub.py
vendored
31
tests/fixtures/hub.py
vendored
@@ -4,7 +4,12 @@ import datasets
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
)
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
|
||||
@@ -41,15 +46,21 @@ def mock_snapshot_download_factory(
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
hf_dataset = hf_dataset_factory(
|
||||
tasks=tasks, episodes=episodes, fps=info["fps"]
|
||||
)
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
episode_index = int(
|
||||
path.stem[len("episode_") :]
|
||||
) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
@@ -74,12 +85,16 @@ def mock_snapshot_download_factory(
|
||||
for episode_dict in episodes:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_path = info["data_path"].format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
all_files,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
@@ -87,7 +102,9 @@ def mock_snapshot_download_factory(
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
_ = single_episode_parquet_path(
|
||||
local_dir, episode_index, hf_dataset, info
|
||||
)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
|
||||
@@ -67,7 +67,9 @@ class GroupSyncRead:
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(
|
||||
motor_index
|
||||
)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
@@ -17,7 +17,9 @@ class config: # noqa: N801
|
||||
def enable_device(self, device_id: str):
|
||||
self.device_enabled = device_id
|
||||
|
||||
def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None):
|
||||
def enable_stream(
|
||||
self, stream_type: stream, width=None, height=None, color_format=None, fps=None
|
||||
):
|
||||
self.stream_type = stream_type
|
||||
# Overwrite default values when possible
|
||||
self.width = 848 if width is None else width
|
||||
|
||||
@@ -77,7 +77,9 @@ class GroupSyncRead:
|
||||
def addParam(self, motor_index): # noqa: N802
|
||||
# Initialize motor default values
|
||||
if motor_index not in self.packet_handler.data:
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(motor_index)
|
||||
self.packet_handler.data[motor_index] = get_default_motor_values(
|
||||
motor_index
|
||||
)
|
||||
|
||||
def txRxPacket(self): # noqa: N802
|
||||
return COMM_SUCCESS
|
||||
|
||||
@@ -25,7 +25,10 @@ from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision.transforms import ToTensor
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
ClassifierConfig,
|
||||
)
|
||||
|
||||
BATCH_SIZE = 1000
|
||||
LR = 0.1
|
||||
@@ -43,7 +46,9 @@ def train_evaluate_multiclass_classifier():
|
||||
logging.info(
|
||||
f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
|
||||
)
|
||||
multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
|
||||
multiclass_config = ClassifierConfig(
|
||||
model_name="microsoft/resnet-18", device=DEVICE, num_classes=10
|
||||
)
|
||||
multiclass_classifier = Classifier(multiclass_config)
|
||||
|
||||
trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
@@ -114,10 +119,18 @@ def train_evaluate_multiclass_classifier():
|
||||
test_probs = torch.stack(test_probs)
|
||||
|
||||
accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
|
||||
precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
|
||||
auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")
|
||||
precision = Precision(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
recall = Recall(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
f1 = F1Score(
|
||||
task="multiclass", average="weighted", num_classes=multiclass_num_classes
|
||||
)
|
||||
auroc = AUROC(
|
||||
task="multiclass", num_classes=multiclass_num_classes, average="weighted"
|
||||
)
|
||||
|
||||
# Calculate metrics
|
||||
acc = accuracy(test_predictions, test_labels)
|
||||
@@ -146,18 +159,28 @@ def train_evaluate_binary_classifier():
|
||||
new_label = float(1.0) if label == target_class else float(0.0)
|
||||
new_targets.append(new_label)
|
||||
|
||||
dataset.targets = new_targets # Replace the original labels with the binary ones
|
||||
dataset.targets = (
|
||||
new_targets # Replace the original labels with the binary ones
|
||||
)
|
||||
return dataset
|
||||
|
||||
binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
|
||||
binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())
|
||||
binary_train_dataset = CIFAR10(
|
||||
root="data", train=True, download=True, transform=ToTensor()
|
||||
)
|
||||
binary_test_dataset = CIFAR10(
|
||||
root="data", train=False, download=True, transform=ToTensor()
|
||||
)
|
||||
|
||||
# Apply one-vs-rest labeling
|
||||
binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
|
||||
binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)
|
||||
|
||||
binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
||||
binary_trainloader = DataLoader(
|
||||
binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True
|
||||
)
|
||||
binary_testloader = DataLoader(
|
||||
binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False
|
||||
)
|
||||
|
||||
binary_epoch = 1
|
||||
|
||||
|
||||
@@ -9,7 +9,9 @@ from tests.utils import require_package
|
||||
|
||||
def test_classifier_output():
|
||||
output = ClassifierOutput(
|
||||
logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
|
||||
logits=torch.tensor([1, 2, 3]),
|
||||
probabilities=torch.tensor([0.1, 0.2, 0.3]),
|
||||
hidden_states=None,
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -20,7 +22,9 @@ def test_classifier_output():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
config = ClassifierConfig()
|
||||
classifier = Classifier(config)
|
||||
@@ -41,7 +45,9 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
num_classes = 5
|
||||
config = ClassifierConfig(num_classes=num_classes)
|
||||
@@ -63,7 +69,9 @@ def test_multiclass_classifier():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
config = ClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
@@ -75,7 +83,9 @@ def test_default_device():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
|
||||
Classifier,
|
||||
)
|
||||
|
||||
config = ClassifierConfig(device="meta")
|
||||
assert config.device == "meta"
|
||||
|
||||
@@ -51,7 +51,13 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
@@ -87,4 +93,6 @@ if __name__ == "__main__":
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
]:
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||
save_dataset_to_safetensors(
|
||||
"tests/data/save_dataset_to_safetensors", repo_id=dataset
|
||||
)
|
||||
|
||||
@@ -67,7 +67,9 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
param_stats[f"{key}_std"] = (
|
||||
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
@@ -85,11 +87,15 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
else:
|
||||
actions_queue = cfg.policy.n_action_repeats
|
||||
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
actions = {
|
||||
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
|
||||
}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def save_policy_to_safetensors(
|
||||
output_dir, env_name, policy_name, extra_overrides, file_name_extra
|
||||
):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
|
||||
if env_policy_dir.exists():
|
||||
@@ -99,7 +105,9 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
shutil.rmtree(env_policy_dir)
|
||||
|
||||
env_policy_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
env_name, policy_name, extra_overrides
|
||||
)
|
||||
save_file(output_dict, env_policy_dir / "output_dict.safetensors")
|
||||
save_file(grad_stats, env_policy_dir / "grad_stats.safetensors")
|
||||
save_file(param_stats, env_policy_dir / "param_stats.safetensors")
|
||||
@@ -129,5 +137,9 @@ if __name__ == "__main__":
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
"tests/data/save_policy_to_safetensors",
|
||||
env,
|
||||
policy,
|
||||
extra_overrides,
|
||||
file_name_extra,
|
||||
)
|
||||
|
||||
@@ -24,7 +24,10 @@ pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]'
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera
|
||||
|
||||
# Maximum absolute difference between two consecutive images recored by a camera.
|
||||
@@ -97,7 +100,11 @@ def test_camera(request, camera_type, mock):
|
||||
)
|
||||
# TODO(rcadene): properly set `rtol`
|
||||
np.testing.assert_allclose(
|
||||
color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
|
||||
color_image,
|
||||
async_color_image,
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
|
||||
# Test disconnecting
|
||||
@@ -116,7 +123,11 @@ def test_camera(request, camera_type, mock):
|
||||
assert camera.color_mode == "bgr"
|
||||
bgr_color_image = camera.read()
|
||||
np.testing.assert_allclose(
|
||||
color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
|
||||
color_image,
|
||||
bgr_color_image[:, :, [2, 1, 0]],
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
del camera
|
||||
|
||||
@@ -151,7 +162,11 @@ def test_camera(request, camera_type, mock):
|
||||
rot_color_image = camera.read()
|
||||
|
||||
np.testing.assert_allclose(
|
||||
rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg
|
||||
rot_color_image,
|
||||
manual_rot_img,
|
||||
rtol=1e-5,
|
||||
atol=MAX_PIXEL_DIFFERENCE,
|
||||
err_msg=error_msg,
|
||||
)
|
||||
del camera
|
||||
|
||||
@@ -185,7 +200,9 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
save_images_from_cameras,
|
||||
)
|
||||
|
||||
# Small `record_time_s` to speedup unit tests
|
||||
save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock)
|
||||
|
||||
@@ -35,7 +35,13 @@ from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
from tests.utils import (
|
||||
DEFAULT_CONFIG_PATH,
|
||||
DEVICE,
|
||||
TEST_ROBOT_TYPES,
|
||||
mock_calibration_dir,
|
||||
require_robot,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@@ -158,7 +164,15 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
|
||||
replay(
|
||||
robot,
|
||||
episode=0,
|
||||
fps=1,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
play_sounds=False,
|
||||
local_files_only=True,
|
||||
)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
@@ -346,8 +360,12 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
run_compute_stats=False,
|
||||
)
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert not mock_events[
|
||||
"rerecord_episode"
|
||||
], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events[
|
||||
"exit_early"
|
||||
], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@@ -394,15 +412,20 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
run_compute_stats=False,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert not mock_events[
|
||||
"exit_early"
|
||||
], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
|
||||
"robot_type, mock, num_image_writer_processes",
|
||||
[("koch", True, 0), ("koch", True, 1)],
|
||||
)
|
||||
@require_robot
|
||||
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||
def test_record_with_event_stop_recording(
|
||||
tmpdir, request, robot_type, mock, num_image_writer_processes
|
||||
):
|
||||
if mock:
|
||||
request.getfixturevalue("patch_builtins_input")
|
||||
|
||||
@@ -444,5 +467,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert not mock_events[
|
||||
"exit_early"
|
||||
], "`exit_early` wasn't properly reset to False"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
@@ -56,7 +56,9 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
dataset_create = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
|
||||
)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
@@ -102,7 +104,16 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
lerobot.env_dataset_policy_triplets
|
||||
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
+ [
|
||||
(
|
||||
"aloha",
|
||||
[
|
||||
"lerobot/aloha_sim_insertion_human",
|
||||
"lerobot/aloha_sim_transfer_cube_human",
|
||||
],
|
||||
"act",
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
@@ -220,7 +231,9 @@ def test_compute_stats_on_xarm():
|
||||
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
# dataset into even batches.
|
||||
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
|
||||
computed_stats = compute_stats(
|
||||
dataset, batch_size=int(len(dataset) * 0.25), num_workers=0
|
||||
)
|
||||
|
||||
# get einops patterns to aggregate batches and compute statistics
|
||||
stats_patterns = get_stats_einops_patterns(dataset)
|
||||
@@ -241,7 +254,9 @@ def test_compute_stats_on_xarm():
|
||||
expected_stats[k] = {}
|
||||
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
|
||||
expected_stats[k]["std"] = torch.sqrt(
|
||||
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
|
||||
einops.reduce(
|
||||
(full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"
|
||||
)
|
||||
)
|
||||
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
|
||||
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
|
||||
@@ -286,7 +301,9 @@ def test_flatten_unflatten_dict():
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(
|
||||
d, sort_keys=True
|
||||
), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@@ -333,7 +350,13 @@ def test_backward_compatibility(repo_id):
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(
|
||||
dataset.episode_data_index["to"][0].item()
|
||||
- dataset.episode_data_index["from"][0].item()
|
||||
)
|
||||
/ 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
@@ -370,23 +393,40 @@ def test_aggregate_stats():
|
||||
data_c = torch.rand(20, dtype=torch.float32)
|
||||
|
||||
hf_dataset_1 = Dataset.from_dict(
|
||||
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
|
||||
{
|
||||
"a": data_a[:10],
|
||||
"b": data_b[:10],
|
||||
"c": data_c[:10],
|
||||
"index": torch.arange(10),
|
||||
}
|
||||
)
|
||||
hf_dataset_1.set_transform(hf_transform_to_torch)
|
||||
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
|
||||
hf_dataset_2 = Dataset.from_dict(
|
||||
{"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}
|
||||
)
|
||||
hf_dataset_2.set_transform(hf_transform_to_torch)
|
||||
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
|
||||
hf_dataset_3 = Dataset.from_dict(
|
||||
{"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}
|
||||
)
|
||||
hf_dataset_3.set_transform(hf_transform_to_torch)
|
||||
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
|
||||
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
|
||||
dataset_1.stats = compute_stats(
|
||||
dataset_1, batch_size=len(hf_dataset_1), num_workers=0
|
||||
)
|
||||
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
|
||||
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
|
||||
dataset_2.stats = compute_stats(
|
||||
dataset_2, batch_size=len(hf_dataset_2), num_workers=0
|
||||
)
|
||||
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
|
||||
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
|
||||
dataset_3.stats = compute_stats(
|
||||
dataset_3, batch_size=len(hf_dataset_3), num_workers=0
|
||||
)
|
||||
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
|
||||
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
|
||||
for agg_fn in ["mean", "min", "max"]:
|
||||
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
|
||||
assert torch.allclose(
|
||||
stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)
|
||||
)
|
||||
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
|
||||
|
||||
|
||||
|
||||
@@ -22,13 +22,17 @@ def synced_hf_dataset_factory(hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
def _create_unsynced_hf_dataset(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just outside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
|
||||
df.at[30, "timestamp"] = dtype.type(
|
||||
df.at[30, "timestamp"] + (tolerance_s * 1.1)
|
||||
)
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
@@ -38,13 +42,17 @@ def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
def _create_slightly_off_hf_dataset(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just inside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
|
||||
df.at[30, "timestamp"] = dtype.type(
|
||||
df.at[30, "timestamp"] + (tolerance_s * 0.9)
|
||||
)
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
@@ -54,8 +62,12 @@ def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
def _create_valid_delta_timestamps(
|
||||
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = {
|
||||
key: [i * (1 / fps) for i in range(-10, 10)] for key in keys
|
||||
}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
@@ -153,7 +165,9 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
|
||||
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
|
||||
single_timestamp_hf_dataset = Dataset.from_dict(
|
||||
{"timestamp": [0.0], "episode_index": [0]}
|
||||
)
|
||||
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
|
||||
fps = 30
|
||||
@@ -202,7 +216,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
|
||||
fps, tolerance_s
|
||||
)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=slightly_off_delta_timestamps,
|
||||
fps=fps,
|
||||
|
||||
@@ -21,7 +21,11 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.datasets.transforms import (
|
||||
RandomSubsetApply,
|
||||
SharpnessJitter,
|
||||
get_image_transforms,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
|
||||
@@ -51,7 +55,9 @@ def default_transforms():
|
||||
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5), max_num_transforms=0
|
||||
)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@@ -149,7 +155,9 @@ def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
def test_backward_compatibility_torchvision(
|
||||
img_tensor_factory, transform, min_max_values, single_transforms
|
||||
):
|
||||
img_tensor = img_tensor_factory()
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
@@ -268,23 +276,33 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
],
|
||||
)
|
||||
def test_visualize_image_transforms(repo_id, n_examples):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"])
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH, overrides=[f"dataset_repo_id={repo_id}"]
|
||||
)
|
||||
output_dir = Path(__file__).parent / "outputs" / "image_transforms"
|
||||
visualize_transforms(cfg, output_dir=output_dir, n_examples=n_examples)
|
||||
output_dir = output_dir / repo_id.split("/")[-1]
|
||||
|
||||
# Check if the original frame image exists
|
||||
assert (output_dir / "original_frame.png").exists(), "Original frame image was not saved."
|
||||
assert (
|
||||
output_dir / "original_frame.png"
|
||||
).exists(), "Original frame image was not saved."
|
||||
|
||||
# Check if the transformed images exist for each transform type
|
||||
transforms = ["brightness", "contrast", "saturation", "hue", "sharpness"]
|
||||
for transform in transforms:
|
||||
transform_dir = output_dir / transform
|
||||
assert transform_dir.exists(), f"{transform} directory was not created."
|
||||
assert any(transform_dir.iterdir()), f"No transformed images found in {transform} directory."
|
||||
assert any(
|
||||
transform_dir.iterdir()
|
||||
), f"No transformed images found in {transform} directory."
|
||||
|
||||
# Check for specific files within each transform directory
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + ["min.png", "max.png", "mean.png"]
|
||||
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [
|
||||
"min.png",
|
||||
"max.png",
|
||||
"mean.png",
|
||||
]
|
||||
for file_name in expected_files:
|
||||
assert (
|
||||
transform_dir / file_name
|
||||
@@ -292,7 +310,9 @@ def test_visualize_image_transforms(repo_id, n_examples):
|
||||
|
||||
# Check if the combined transforms directory exists and contains the right files
|
||||
combined_transforms_dir = output_dir / "all"
|
||||
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
|
||||
assert (
|
||||
combined_transforms_dir.exists()
|
||||
), "Combined transforms directory was not created."
|
||||
assert any(
|
||||
combined_transforms_dir.iterdir()
|
||||
), "No transformed images found in combined transforms directory."
|
||||
|
||||
@@ -160,7 +160,9 @@ def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
@@ -175,7 +177,9 @@ def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(
|
||||
np.uint8
|
||||
)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
@@ -265,7 +269,9 @@ def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
||||
image_arrays = [
|
||||
img_array_factory(height=500, width=500) for _ in range(num_images)
|
||||
]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -30,7 +30,10 @@ import time
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.scripts.find_motors_bus_port import find_port
|
||||
from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor
|
||||
|
||||
@@ -63,7 +66,9 @@ def test_configure_motors_all_ids_1(request, motor_type, mock):
|
||||
else:
|
||||
raise ValueError(motor_type)
|
||||
|
||||
input("Are you sure you want to re-configure the motors? Press enter to continue...")
|
||||
input(
|
||||
"Are you sure you want to re-configure the motors? Press enter to continue..."
|
||||
)
|
||||
# This test expect the configuration was already correct.
|
||||
motors_bus = make_motors_bus(motor_type, mock=mock)
|
||||
motors_bus.connect()
|
||||
|
||||
@@ -44,13 +44,23 @@ def make_new_buffer(
|
||||
return buffer, write_dir
|
||||
|
||||
|
||||
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
|
||||
def make_spoof_data_frames(
|
||||
n_episodes: int, n_frames_per_episode: int
|
||||
) -> dict[str, np.ndarray]:
|
||||
new_data = {
|
||||
data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
|
||||
data_key: np.arange(
|
||||
n_frames_per_episode * n_episodes * np.prod(data_shape)
|
||||
).reshape(-1, *data_shape),
|
||||
OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(
|
||||
np.arange(n_episodes), n_frames_per_episode
|
||||
),
|
||||
OnlineBuffer.FRAME_INDEX_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode), n_episodes
|
||||
),
|
||||
OnlineBuffer.TIMESTAMP_KEY: np.tile(
|
||||
np.arange(n_frames_per_episode) / fps, n_episodes
|
||||
),
|
||||
}
|
||||
return new_data
|
||||
|
||||
@@ -133,8 +143,8 @@ def test_fifo():
|
||||
n_more_episodes = 2
|
||||
# Developer sanity check (in case someone changes the global `buffer_capacity`).
|
||||
assert (
|
||||
n_episodes + n_more_episodes
|
||||
) * n_frames_per_episode > buffer_capacity, "Something went wrong with the test code."
|
||||
(n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity
|
||||
), "Something went wrong with the test code."
|
||||
more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
|
||||
buffer.add_data(more_new_data)
|
||||
assert len(buffer) == buffer_capacity, "The buffer should be full."
|
||||
@@ -166,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
|
||||
assert torch.allclose(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
assert torch.allclose(
|
||||
data, torch.tensor([0, 2, 3])
|
||||
), "Data does not match expected values"
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
@@ -202,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
buffer.tolerance_s = 0.04
|
||||
item = buffer[2]
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
data, torch.tensor([0, 0, 2, 4, 4])
|
||||
), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
@@ -219,58 +233,89 @@ def test_compute_sampler_weights_trivial(
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=offline_dataset_size
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
|
||||
make_spoof_data_frames(
|
||||
n_episodes=2, n_frames_per_episode=online_dataset_size // 2
|
||||
)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
)
|
||||
if offline_dataset_size == 0 or online_dataset_size == 0:
|
||||
expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
|
||||
elif online_sampling_ratio == 0:
|
||||
expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
|
||||
expected_weights = torch.cat(
|
||||
[torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)]
|
||||
)
|
||||
elif online_sampling_ratio == 1:
|
||||
expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
|
||||
expected_weights = torch.cat(
|
||||
[torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)]
|
||||
)
|
||||
expected_weights /= expected_weights.sum()
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
online_sampling_ratio = 0.8
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=online_sampling_ratio,
|
||||
)
|
||||
assert torch.allclose(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
|
||||
lerobot_dataset_factory, tmp_path
|
||||
):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=4
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
|
||||
offline_dataset,
|
||||
online_dataset=online_dataset,
|
||||
online_sampling_ratio=0.8,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
assert torch.allclose(
|
||||
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
|
||||
weights,
|
||||
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
offline_dataset = lerobot_dataset_factory(
|
||||
tmp_path, total_episodes=1, total_frames=2
|
||||
)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_dataset.add_data(
|
||||
make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2)
|
||||
)
|
||||
|
||||
weights = compute_sampler_weights(
|
||||
offline_dataset,
|
||||
@@ -279,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
|
||||
online_sampling_ratio=0.5,
|
||||
online_drop_n_last_frames=1,
|
||||
)
|
||||
assert torch.allclose(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
|
||||
assert torch.allclose(
|
||||
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
|
||||
)
|
||||
|
||||
@@ -39,7 +39,13 @@ from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.train import make_optimizer_and_scheduler
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
from tests.utils import (
|
||||
DEFAULT_CONFIG_PATH,
|
||||
DEVICE,
|
||||
require_cpu,
|
||||
require_env,
|
||||
require_x86_64_kernel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
@@ -47,37 +53,63 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
|
||||
assert policy_cls.name == policy_name
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
assert issubclass(
|
||||
config_cls,
|
||||
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
|
||||
(
|
||||
"xarm",
|
||||
"tdmpc",
|
||||
["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"],
|
||||
),
|
||||
("pusht", "diffusion", []),
|
||||
("pusht", "vqbet", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_scripted"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_human"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_scripted",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["env.task=AlohaTransferCube-v0", "dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
|
||||
[
|
||||
"env.task=AlohaTransferCube-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_transfer_cube_human",
|
||||
],
|
||||
),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
[
|
||||
"env.task=AlohaTransferCube-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_transfer_cube_scripted",
|
||||
],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
(
|
||||
"aloha",
|
||||
"diffusion",
|
||||
["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"],
|
||||
[
|
||||
"env.task=AlohaInsertion-v0",
|
||||
"dataset_repo_id=lerobot/aloha_sim_insertion_human",
|
||||
],
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
@@ -165,7 +197,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
# Test updating the policy (and test that it does not mutate the batch)
|
||||
batch_ = deepcopy(batch)
|
||||
policy.forward(batch)
|
||||
assert set(batch) == set(batch_), "Batch keys are not the same after a forward pass."
|
||||
assert set(batch) == set(
|
||||
batch_
|
||||
), "Batch keys are not the same after a forward pass."
|
||||
assert all(
|
||||
torch.equal(batch[k], batch_[k]) for k in batch
|
||||
), "Batch values are not the same after a forward pass."
|
||||
@@ -178,7 +212,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
# send observation to device/gpu
|
||||
observation = {key: observation[key].to(DEVICE, non_blocking=True) for key in observation}
|
||||
observation = {
|
||||
key: observation[key].to(DEVICE, non_blocking=True) for key in observation
|
||||
}
|
||||
|
||||
# get the next action for the environment (also check that the observation batch is not modified)
|
||||
observation_ = deepcopy(observation)
|
||||
@@ -240,7 +276,9 @@ def test_policy_defaults(policy_name: str):
|
||||
)
|
||||
def test_yaml_matches_dataclass(env_name: str, policy_name: str):
|
||||
"""Check that dataclass configs match their respective yaml configs."""
|
||||
hydra_cfg = init_hydra_config(DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"])
|
||||
hydra_cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH, overrides=[f"env={env_name}", f"policy={policy_name}"]
|
||||
)
|
||||
_, policy_cfg_cls = get_policy_and_config_classes(policy_name)
|
||||
policy_cfg_from_hydra = _policy_cfg_from_hydra_cfg(policy_cfg_cls, hydra_cfg)
|
||||
policy_cfg_from_dataclass = policy_cfg_cls()
|
||||
@@ -254,7 +292,10 @@ def test_save_and_load_pretrained(policy_name: str):
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
policy.save_pretrained(save_dir)
|
||||
policy_ = policy_cls.from_pretrained(save_dir)
|
||||
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
|
||||
assert all(
|
||||
torch.equal(p, p_)
|
||||
for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
|
||||
@@ -343,7 +384,9 @@ def test_normalize(insert_temporal_dim):
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test with stats
|
||||
unnormalize = Unnormalize(output_shapes, unnormalize_output_modes, stats=dataset_stats)
|
||||
unnormalize = Unnormalize(
|
||||
output_shapes, unnormalize_output_modes, stats=dataset_stats
|
||||
)
|
||||
unnormalize(output_batch)
|
||||
|
||||
# test loading pretrained models
|
||||
@@ -364,11 +407,20 @@ def test_normalize(insert_temporal_dim):
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
[
|
||||
"policy.n_action_steps=8",
|
||||
"policy.num_inference_steps=10",
|
||||
"policy.down_dims=[128, 256, 512]",
|
||||
],
|
||||
"",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
(
|
||||
"aloha",
|
||||
"act",
|
||||
["policy.n_action_steps=1000", "policy.chunk_size=1000"],
|
||||
"_1000_steps",
|
||||
),
|
||||
("dora_aloha_real", "act_aloha_real", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
@@ -376,7 +428,9 @@ def test_normalize(insert_temporal_dim):
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
|
||||
def test_backward_compatibility(
|
||||
env_name, policy_name, extra_overrides, file_name_extra
|
||||
):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
@@ -390,23 +444,34 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
"""
|
||||
env_policy_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
Path("tests/data/save_policy_to_safetensors")
|
||||
/ f"{env_name}_{policy_name}{file_name_extra}"
|
||||
)
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
saved_actions = load_file(env_policy_dir / "actions.safetensors")
|
||||
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(env_name, policy_name, extra_overrides)
|
||||
output_dict, grad_stats, param_stats, actions = get_policy_stats(
|
||||
env_name, policy_name, extra_overrides
|
||||
)
|
||||
|
||||
for key in saved_output_dict:
|
||||
assert torch.isclose(output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
output_dict[key], saved_output_dict[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
for key in saved_grad_stats:
|
||||
assert torch.isclose(grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
grad_stats[key], saved_grad_stats[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
for key in saved_param_stats:
|
||||
assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7
|
||||
).all()
|
||||
for key in saved_actions:
|
||||
assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
|
||||
assert torch.isclose(
|
||||
actions[key], saved_actions[key], rtol=0.1, atol=1e-7
|
||||
).all()
|
||||
|
||||
|
||||
def test_act_temporal_ensembler():
|
||||
@@ -432,7 +497,9 @@ def test_act_temporal_ensembler():
|
||||
batch_size = batch_seq.shape[0]
|
||||
# Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
|
||||
# dimension of `batch_seq`.
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
|
||||
weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(
|
||||
-1
|
||||
)
|
||||
|
||||
# Simulate stepping through a rollout and computing a batch of actions with model on each step.
|
||||
for i in range(episode_length):
|
||||
@@ -455,7 +522,8 @@ def test_act_temporal_ensembler():
|
||||
episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
|
||||
seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
|
||||
offline_avg = (
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
|
||||
einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum")
|
||||
/ weights[: i + 1].sum()
|
||||
)
|
||||
# Sanity check. The average should be between the extrema.
|
||||
assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
|
||||
|
||||
@@ -31,7 +31,11 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
|
||||
zarr_data = zarr.group(store=store)
|
||||
|
||||
zarr_data.create_dataset(
|
||||
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
|
||||
"data/action",
|
||||
shape=(num_frames, 1),
|
||||
chunks=(num_frames, 1),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/img",
|
||||
@@ -41,20 +45,38 @@ def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
|
||||
"data/n_contacts",
|
||||
shape=(num_frames, 2),
|
||||
chunks=(num_frames, 2),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
"data/state",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
|
||||
"data/keypoint",
|
||||
shape=(num_frames, 9, 2),
|
||||
chunks=(num_frames, 9, 2),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
"meta/episode_ends",
|
||||
shape=(num_episodes,),
|
||||
chunks=(num_episodes,),
|
||||
dtype=np.int32,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
|
||||
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/img"][:] = np.random.randint(
|
||||
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
|
||||
)
|
||||
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
|
||||
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
|
||||
@@ -93,7 +115,11 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
|
||||
"data/robot0_eef_pos",
|
||||
shape=(num_frames, 5),
|
||||
chunks=(num_frames, 5),
|
||||
dtype=np.float32,
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"data/robot0_eef_rot_axis_angle",
|
||||
@@ -110,10 +136,16 @@ def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
|
||||
overwrite=True,
|
||||
)
|
||||
zarr_data.create_dataset(
|
||||
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
|
||||
"meta/episode_ends",
|
||||
shape=(num_episodes,),
|
||||
chunks=(num_episodes,),
|
||||
dtype=np.int32,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
|
||||
zarr_data["data/camera0_rgb"][:] = np.random.randint(
|
||||
0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8
|
||||
)
|
||||
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
|
||||
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
|
||||
@@ -129,7 +161,9 @@ def _mock_download_raw_xarm(raw_dir, num_frames=4):
|
||||
|
||||
dataset_dict = {
|
||||
"observations": {
|
||||
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
|
||||
"rgb": np.random.randint(
|
||||
0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8
|
||||
),
|
||||
"state": np.random.randn(num_frames, 4),
|
||||
},
|
||||
"actions": np.random.randn(num_frames, 3),
|
||||
@@ -151,13 +185,24 @@ def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
|
||||
with h5py.File(str(path_h5), "w") as f:
|
||||
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
|
||||
f.create_dataset(
|
||||
"action", data=np.random.randn(num_frames // num_episodes, 14)
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/qpos",
|
||||
data=np.random.randn(num_frames // num_episodes, 14),
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/qvel",
|
||||
data=np.random.randn(num_frames // num_episodes, 14),
|
||||
)
|
||||
f.create_dataset(
|
||||
"observations/images/top",
|
||||
data=np.random.randint(
|
||||
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
|
||||
0,
|
||||
255,
|
||||
size=(num_frames // num_episodes, 480, 640, 3),
|
||||
dtype=np.uint8,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -191,7 +236,12 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
||||
action = np.random.randn(21).tolist()
|
||||
state = np.random.randn(21).tolist()
|
||||
ep_idx = episode_indices_mapping[i]
|
||||
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
|
||||
frame = [
|
||||
{
|
||||
"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4",
|
||||
"timestamp": frame_idx / fps,
|
||||
}
|
||||
]
|
||||
timestamps.append(t_utc)
|
||||
actions.append(action)
|
||||
states.append(state)
|
||||
@@ -204,7 +254,9 @@ def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
|
||||
|
||||
# write fake mp4 file for each episode
|
||||
for ep_idx in range(num_episodes):
|
||||
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
|
||||
imgs_array = np.random.randint(
|
||||
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
|
||||
)
|
||||
|
||||
tmp_imgs_dir = raw_dir / "tmp_images"
|
||||
save_images_concurrently(imgs_array, tmp_imgs_dir)
|
||||
@@ -263,7 +315,9 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
],
|
||||
)
|
||||
@require_package_arg
|
||||
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
|
||||
def test_push_dataset_to_hub_format(
|
||||
required_packages, tmpdir, raw_format, repo_id, make_test_data
|
||||
):
|
||||
num_episodes = 3
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
@@ -315,7 +369,10 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
|
||||
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
|
||||
)
|
||||
for k in ["from", "to"]:
|
||||
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
|
||||
assert torch.equal(
|
||||
test_dataset.episode_data_index[k],
|
||||
lerobot_dataset.episode_data_index[k][:1],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -359,8 +416,12 @@ def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, re
|
||||
assert item1.keys() == item2.keys(), "Keys mismatch"
|
||||
|
||||
for key in item1:
|
||||
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
|
||||
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
|
||||
if isinstance(item1[key], torch.Tensor) and isinstance(
|
||||
item2[key], torch.Tensor
|
||||
):
|
||||
assert torch.equal(
|
||||
item1[key], item2[key]
|
||||
), f"Mismatch found in key: {key}"
|
||||
else:
|
||||
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
|
||||
|
||||
|
||||
@@ -29,8 +29,16 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from tests.utils import TEST_ROBOT_TYPES, make_robot, mock_calibration_dir, require_robot
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from tests.utils import (
|
||||
TEST_ROBOT_TYPES,
|
||||
make_robot,
|
||||
mock_calibration_dir,
|
||||
require_robot,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
|
||||
@@ -104,7 +112,9 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
assert "observation.state" in observation
|
||||
assert isinstance(observation["observation.state"], torch.Tensor)
|
||||
assert observation["observation.state"].ndim == 1
|
||||
dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||
dim_state = sum(
|
||||
len(robot.follower_arms[name].motors) for name in robot.follower_arms
|
||||
)
|
||||
assert observation["observation.state"].shape[0] == dim_state
|
||||
# Cameras
|
||||
for name in robot.cameras:
|
||||
@@ -115,7 +125,9 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
assert "action" in action
|
||||
assert isinstance(action["action"], torch.Tensor)
|
||||
assert action["action"].ndim == 1
|
||||
dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms)
|
||||
dim_action = sum(
|
||||
len(robot.follower_arms[name].motors) for name in robot.follower_arms
|
||||
)
|
||||
assert action["action"].shape[0] == dim_action
|
||||
# TODO(rcadene): test if observation and action data are returned as expected
|
||||
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
# limitations under the License.
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
)
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
|
||||
@@ -9,7 +9,9 @@ from hydra import compose, initialize_config_dir
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.scripts.train_hilserl_classifier import (
|
||||
create_balanced_sampler,
|
||||
@@ -34,7 +36,9 @@ class MockDataset(Dataset):
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
|
||||
num_classes=2,
|
||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
||||
num_cameras=1,
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
@@ -65,7 +69,9 @@ def test_create_balanced_sampler():
|
||||
labels = [item["label"] for item in data]
|
||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
||||
class_weights = 1.0 / class_counts
|
||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
||||
expected_weights = torch.tensor(
|
||||
[class_weights[label] for label in labels], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Test that the weights are correct
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
@@ -149,7 +155,9 @@ def test_validate():
|
||||
|
||||
def test_train_epoch_multiple_cameras():
|
||||
model_config = ClassifierConfig(
|
||||
num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
|
||||
num_classes=2,
|
||||
model_name="hf-tiny-model-private/tiny-random-ResNetModel",
|
||||
num_cameras=2,
|
||||
)
|
||||
model = Classifier(config=model_config)
|
||||
|
||||
@@ -216,10 +224,16 @@ def test_resume_function(
|
||||
):
|
||||
# Initialize Hydra
|
||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
||||
config_dir = os.path.abspath(
|
||||
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
|
||||
)
|
||||
assert os.path.exists(
|
||||
config_dir
|
||||
), f"Config directory does not exist at {config_dir}"
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
with initialize_config_dir(
|
||||
config_dir=config_dir, job_name="test_app", version_base="1.2"
|
||||
):
|
||||
cfg = compose(
|
||||
config_name="hilserl_classifier",
|
||||
overrides=[
|
||||
@@ -244,7 +258,9 @@ def test_resume_function(
|
||||
mock_init_hydra_config.return_value = cfg
|
||||
|
||||
# Mock dataset
|
||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
||||
dataset = MockDataset(
|
||||
[{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)]
|
||||
)
|
||||
mock_dataset.return_value = dataset
|
||||
|
||||
# Mock checkpoint handling
|
||||
|
||||
@@ -7,7 +7,9 @@ import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
@@ -26,7 +26,9 @@ import torch
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.robot_devices.cameras.utils import Camera
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot as make_robot_from_cfg
|
||||
from lerobot.common.robot_devices.robots.factory import (
|
||||
make_robot as make_robot_from_cfg,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -52,9 +54,13 @@ for motor_type in available_motors:
|
||||
|
||||
# Camera indices used for connecting physical cameras
|
||||
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
|
||||
INTELREALSENSE_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_CAMERA_INDEX", 128422271614))
|
||||
INTELREALSENSE_CAMERA_INDEX = int(
|
||||
os.environ.get("LEROBOT_TEST_INTELREALSENSE_CAMERA_INDEX", 128422271614)
|
||||
)
|
||||
|
||||
DYNAMIXEL_PORT = os.environ.get("LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081")
|
||||
DYNAMIXEL_PORT = os.environ.get(
|
||||
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"
|
||||
)
|
||||
DYNAMIXEL_MOTORS = {
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
@@ -64,7 +70,9 @@ DYNAMIXEL_MOTORS = {
|
||||
"gripper": [6, "xl330-m288"],
|
||||
}
|
||||
|
||||
FEETECH_PORT = os.environ.get("LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971")
|
||||
FEETECH_PORT = os.environ.get(
|
||||
"LEROBOT_TEST_FEETECH_PORT", "/dev/tty.usbmodem585A0080971"
|
||||
)
|
||||
FEETECH_MOTORS = {
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
@@ -163,9 +171,13 @@ def require_package_arg(func):
|
||||
if "required_packages" in arg_names:
|
||||
# Get the index of 'required_packages' and retrieve the value from args
|
||||
index = arg_names.index("required_packages")
|
||||
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
|
||||
required_packages = (
|
||||
args[index] if len(args) > index else kwargs.get("required_packages")
|
||||
)
|
||||
else:
|
||||
raise ValueError("Function does not have 'required_packages' as an argument.")
|
||||
raise ValueError(
|
||||
"Function does not have 'required_packages' as an argument."
|
||||
)
|
||||
|
||||
if required_packages is None:
|
||||
return func(*args, **kwargs)
|
||||
@@ -222,11 +234,17 @@ def require_robot(func):
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if robot_type is None:
|
||||
raise ValueError("The 'robot_type' must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'robot_type' must be an argument of the test function."
|
||||
)
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'request' fixture must be an argument of the test function."
|
||||
)
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'mock' variable must be an argument of the test function."
|
||||
)
|
||||
|
||||
# Run test with a real robot. Skip test if robot connection fails.
|
||||
if not mock and not request.getfixturevalue("is_robot_available"):
|
||||
@@ -246,11 +264,17 @@ def require_camera(func):
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'request' fixture must be an argument of the test function."
|
||||
)
|
||||
if camera_type is None:
|
||||
raise ValueError("The 'camera_type' must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'camera_type' must be an argument of the test function."
|
||||
)
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'mock' variable must be an argument of the test function."
|
||||
)
|
||||
|
||||
if not mock and not request.getfixturevalue("is_camera_available"):
|
||||
pytest.skip(f"A {camera_type} camera is not available.")
|
||||
@@ -269,11 +293,17 @@ def require_motor(func):
|
||||
mock = kwargs.get("mock")
|
||||
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'request' fixture must be an argument of the test function."
|
||||
)
|
||||
if motor_type is None:
|
||||
raise ValueError("The 'motor_type' must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'motor_type' must be an argument of the test function."
|
||||
)
|
||||
if mock is None:
|
||||
raise ValueError("The 'mock' variable must be an argument of the test function.")
|
||||
raise ValueError(
|
||||
"The 'mock' variable must be an argument of the test function."
|
||||
)
|
||||
|
||||
if not mock and not request.getfixturevalue("is_motor_available"):
|
||||
pytest.skip(f"A {motor_type} motor is not available.")
|
||||
@@ -292,7 +322,14 @@ def mock_calibration_dir(calibration_dir):
|
||||
"start_pos": [1442, 843, 2166, 2849, 1988, 1835],
|
||||
"end_pos": [2440, 1869, -1106, -1848, -926, 3235],
|
||||
"calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"],
|
||||
"motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
"motor_names": [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
],
|
||||
}
|
||||
Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True)
|
||||
with open(calibration_dir / "main_follower.json", "w") as f:
|
||||
@@ -309,7 +346,9 @@ def mock_calibration_dir(calibration_dir):
|
||||
json.dump(example_calib, f)
|
||||
|
||||
|
||||
def make_robot(robot_type: str, overrides: list[str] | None = None, mock=False) -> Robot:
|
||||
def make_robot(
|
||||
robot_type: str, overrides: list[str] | None = None, mock=False
|
||||
) -> Robot:
|
||||
if mock:
|
||||
overrides = [] if overrides is None else copy(overrides)
|
||||
|
||||
@@ -359,7 +398,9 @@ def make_camera(camera_type, **kwargs) -> Camera:
|
||||
return OpenCVCamera(camera_index, **kwargs)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
IntelRealSenseCamera,
|
||||
)
|
||||
|
||||
camera_index = kwargs.pop("camera_index", INTELREALSENSE_CAMERA_INDEX)
|
||||
return IntelRealSenseCamera(camera_index, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user