Package folder structure (#1417)
* Move files * Replace imports & paths * Update relative paths * Update doc symlinks * Update instructions paths * Fix imports * Update grpc files * Update more instructions * Downgrade grpc-tools * Update manifest * Update more paths * Update config paths * Update CI paths * Update bandit exclusions * Remove walkthrough section
This commit is contained in:
@@ -31,7 +31,7 @@ from pathlib import Path
|
||||
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
|
||||
@@ -18,14 +18,14 @@ from pathlib import Path
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import (
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
DATASET_REPO_ID = "lerobot/aloha_static_cups_open"
|
||||
|
||||
@@ -19,12 +19,12 @@ from pathlib import Path
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.common.utils.random_utils import set_seed
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_policy_config
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
|
||||
def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
|
||||
@@ -24,9 +24,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.cameras.configs import Cv2Rotation
|
||||
from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.cameras.configs import Cv2Rotation
|
||||
from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
# NOTE(Steven): more tests + assertions?
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
||||
|
||||
@@ -25,12 +25,12 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.cameras.configs import Cv2Rotation
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.cameras.configs import Cv2Rotation
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
pytest.importorskip("pyrealsense2")
|
||||
|
||||
from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
from lerobot.cameras.realsense import RealSenseCamera, RealSenseCameraConfig
|
||||
|
||||
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras"
|
||||
BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag"
|
||||
|
||||
@@ -5,15 +5,15 @@ from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
|
||||
|
||||
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
|
||||
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
|
||||
return f"""
|
||||
from dataclasses import dataclass
|
||||
from lerobot.common.envs.configs import {base_class}
|
||||
from lerobot.envs.configs import {base_class}
|
||||
|
||||
@{base_class}.register_subclass("{plugin_name}")
|
||||
@dataclass
|
||||
|
||||
@@ -18,7 +18,7 @@ from unittest.mock import patch
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.compute_stats import (
|
||||
from lerobot.datasets.compute_stats import (
|
||||
_assert_type_and_shape,
|
||||
aggregate_feature_stats,
|
||||
aggregate_stats,
|
||||
@@ -61,7 +61,7 @@ def test_sample_indices():
|
||||
assert len(indices) == estimate_num_samples(10)
|
||||
|
||||
|
||||
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||
@patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
|
||||
def test_sample_images(mock_load):
|
||||
image_paths = [f"image_{i}.jpg" for i in range(100)]
|
||||
images = sample_images(image_paths)
|
||||
@@ -144,9 +144,7 @@ def test_compute_episode_stats():
|
||||
"observation.state": {"dtype": "numeric"},
|
||||
}
|
||||
|
||||
with patch(
|
||||
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
|
||||
):
|
||||
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
|
||||
stats = compute_episode_stats(episode_data, features)
|
||||
|
||||
assert "observation.image" in stats and "observation.state" in stats
|
||||
|
||||
@@ -28,21 +28,21 @@ from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -558,7 +558,7 @@ def test_create_branch():
|
||||
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
|
||||
@@ -19,7 +19,7 @@ import pyarrow.compute as pc
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
|
||||
@@ -21,7 +21,7 @@ from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.transforms import (
|
||||
from lerobot.datasets.transforms import (
|
||||
ImageTransformConfig,
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
@@ -29,11 +29,11 @@ from lerobot.common.datasets.transforms import (
|
||||
SharpnessJitter,
|
||||
make_transform_from_config,
|
||||
)
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.image_writer import (
|
||||
from lerobot.datasets.image_writer import (
|
||||
AsyncImageWriter,
|
||||
image_array_to_pil_image,
|
||||
safe_stop_image_writer,
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
|
||||
# Some constants for OnlineBuffer tests.
|
||||
data_key = "data"
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
# limitations under the License.
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,8 +18,8 @@ import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
|
||||
@@ -21,8 +21,8 @@ import torch
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from tests.utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
2
tests/fixtures/constants.py
vendored
2
tests/fixtures/constants.py
vendored
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
|
||||
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
|
||||
DUMMY_REPO_ID = "dummy/repo"
|
||||
|
||||
18
tests/fixtures/dataset_factories.py
vendored
18
tests/fixtures/dataset_factories.py
vendored
@@ -23,8 +23,8 @@ import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
@@ -351,10 +351,8 @@ def lerobot_dataset_metadata_factory(
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
@@ -428,11 +426,9 @@ def lerobot_dataset_factory(
|
||||
episodes=episode_dicts,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
|
||||
2
tests/fixtures/files.py
vendored
2
tests/fixtures/files.py
vendored
@@ -20,7 +20,7 @@ import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
INFO_PATH,
|
||||
|
||||
2
tests/fixtures/hub.py
vendored
2
tests/fixtures/hub.py
vendored
@@ -17,7 +17,7 @@ import datasets
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
from lerobot.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
INFO_PATH,
|
||||
|
||||
4
tests/fixtures/optimizers.py
vendored
4
tests/fixtures/optimizers.py
vendored
@@ -14,8 +14,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamConfig
|
||||
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
|
||||
from lerobot.optim.optimizers import AdamConfig
|
||||
from lerobot.optim.schedulers import VQBeTSchedulerConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -5,7 +5,7 @@ import dynamixel_sdk as dxl
|
||||
import serial
|
||||
from mock_serial.mock_serial import MockSerial
|
||||
|
||||
from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks
|
||||
from lerobot.motors.dynamixel.dynamixel import _split_into_byte_chunks
|
||||
|
||||
from .mock_serial_patch import WaitableStub
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ import scservo_sdk as scs
|
||||
import serial
|
||||
from mock_serial import MockSerial
|
||||
|
||||
from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
|
||||
from lerobot.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout
|
||||
|
||||
from .mock_serial_patch import WaitableStub
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ruff: noqa: N802
|
||||
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
from lerobot.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorsBus,
|
||||
)
|
||||
|
||||
@@ -3,9 +3,9 @@ from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.robots import Robot, RobotConfig
|
||||
from lerobot.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.robots import Robot, RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("mock_robot")
|
||||
|
||||
@@ -3,8 +3,8 @@ from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
||||
|
||||
@@ -5,10 +5,10 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
||||
from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
|
||||
from lerobot.common.utils.encoding_utils import encode_twos_complement
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus
|
||||
from lerobot.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE
|
||||
from lerobot.utils.encoding_utils import encode_twos_complement
|
||||
|
||||
try:
|
||||
import dynamixel_sdk as dxl
|
||||
@@ -389,7 +389,7 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||
read_pos_stub = mock_motors.build_sequential_sync_read_stub(
|
||||
*X_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||
)
|
||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
|
||||
from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
|
||||
from lerobot.common.utils.encoding_utils import encode_sign_magnitude
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus
|
||||
from lerobot.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE
|
||||
from lerobot.utils.encoding_utils import encode_sign_magnitude
|
||||
|
||||
try:
|
||||
import scservo_sdk as scs
|
||||
@@ -432,7 +432,7 @@ def test_record_ranges_of_motion(mock_motors, dummy_motors):
|
||||
stub = mock_motors.build_sequential_sync_read_stub(
|
||||
*STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions
|
||||
)
|
||||
with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
with patch("lerobot.motors.motors_bus.enter_pressed", side_effect=[False, True]):
|
||||
bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors)
|
||||
bus.connect(handshake=False)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.motors.motors_bus import (
|
||||
from lerobot.motors.motors_bus import (
|
||||
Motor,
|
||||
MotorNormMode,
|
||||
assert_same_address,
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.constants import (
|
||||
from lerobot.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.common.optim.optimizers import (
|
||||
from lerobot.optim.optimizers import (
|
||||
AdamConfig,
|
||||
AdamWConfig,
|
||||
MultiAdamConfig,
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
from lerobot.common.constants import SCHEDULER_STATE
|
||||
from lerobot.common.optim.schedulers import (
|
||||
from lerobot.constants import SCHEDULER_STATE
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
DiffuserSchedulerConfig,
|
||||
VQBeTSchedulerConfig,
|
||||
|
||||
@@ -16,9 +16,9 @@
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import ClassifierOutput
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ def test_classifier_output():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_binary_classifier_with_default_params():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
config.input_features = {
|
||||
@@ -78,7 +78,7 @@ def test_binary_classifier_with_default_params():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_multiclass_classifier():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
num_classes = 5
|
||||
config = RewardClassifierConfig()
|
||||
@@ -117,7 +117,7 @@ def test_multiclass_classifier():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_default_device():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig()
|
||||
assert config.device == "cpu"
|
||||
@@ -129,7 +129,7 @@ def test_default_device():
|
||||
|
||||
@require_package("transformers")
|
||||
def test_explicit_device_setup():
|
||||
from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
|
||||
config = RewardClassifierConfig(device="cpu")
|
||||
assert config.device == "cpu"
|
||||
|
||||
@@ -24,23 +24,23 @@ from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.common.policies.factory import (
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
make_policy,
|
||||
make_policy_config,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.random_utils import seeded_context
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -16,7 +16,8 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import (
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import (
|
||||
ActorLearnerConfig,
|
||||
ActorNetworkConfig,
|
||||
ConcurrencyConfig,
|
||||
@@ -24,7 +25,6 @@ from lerobot.common.policies.sac.configuration_sac import (
|
||||
PolicyConfig,
|
||||
SACConfig,
|
||||
)
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
|
||||
|
||||
def test_sac_config_default_initialization():
|
||||
|
||||
@@ -20,10 +20,10 @@ import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.common.utils.random_utils import seeded_context, set_seed
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
try:
|
||||
import transformers # noqa: F401
|
||||
|
||||
@@ -21,14 +21,14 @@ import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
def create_learner_service_stub():
|
||||
import grpc
|
||||
|
||||
from lerobot.common.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
|
||||
class MockLearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
def __init__(self):
|
||||
@@ -101,8 +101,8 @@ def test_establish_learner_connection_failure():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_push_transitions_to_transport_queue():
|
||||
from lerobot.common.transport.utils import bytes_to_transitions
|
||||
from lerobot.scripts.rl.actor import push_transitions_to_transport_queue
|
||||
from lerobot.transport.utils import bytes_to_transitions
|
||||
from tests.transport.test_transport_utils import assert_transitions_equal
|
||||
|
||||
"""Test pushing transitions to transport queue."""
|
||||
@@ -169,8 +169,8 @@ def test_transitions_stream():
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_interactions_stream():
|
||||
from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
from lerobot.scripts.rl.actor import interactions_stream
|
||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
|
||||
"""Test interactions stream functionality."""
|
||||
shutdown_event = Event()
|
||||
|
||||
@@ -22,9 +22,9 @@ import pytest
|
||||
import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -90,7 +90,6 @@ def cfg():
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(10) # force cross-platform watchdog
|
||||
def test_end_to_end_transitions_flow(cfg):
|
||||
from lerobot.common.transport.utils import bytes_to_transitions
|
||||
from lerobot.scripts.rl.actor import (
|
||||
establish_learner_connection,
|
||||
learner_service_client,
|
||||
@@ -98,6 +97,7 @@ def test_end_to_end_transitions_flow(cfg):
|
||||
send_transitions,
|
||||
)
|
||||
from lerobot.scripts.rl.learner import start_learner
|
||||
from lerobot.transport.utils import bytes_to_transitions
|
||||
from tests.transport.test_transport_utils import assert_transitions_equal
|
||||
|
||||
"""Test complete transitions flow from actor to learner."""
|
||||
@@ -152,13 +152,13 @@ def test_end_to_end_transitions_flow(cfg):
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(10)
|
||||
def test_end_to_end_interactions_flow(cfg):
|
||||
from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
from lerobot.scripts.rl.actor import (
|
||||
establish_learner_connection,
|
||||
learner_service_client,
|
||||
send_interactions,
|
||||
)
|
||||
from lerobot.scripts.rl.learner import start_learner
|
||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
|
||||
"""Test complete interactions flow from actor to learner."""
|
||||
# Queues for actor-learner communication
|
||||
@@ -226,9 +226,9 @@ def test_end_to_end_interactions_flow(cfg):
|
||||
@pytest.mark.parametrize("data_size", ["small", "large"])
|
||||
@pytest.mark.timeout(10)
|
||||
def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy
|
||||
from lerobot.scripts.rl.learner import start_learner
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
"""Test complete parameter flow from learner to actor, with small and large data."""
|
||||
# Actor's local queue to receive params
|
||||
|
||||
@@ -50,8 +50,8 @@ def create_learner_service_stub(
|
||||
):
|
||||
import grpc
|
||||
|
||||
from lerobot.common.transport import services_pb2_grpc # generated from .proto
|
||||
from lerobot.scripts.rl.learner_service import LearnerService
|
||||
from lerobot.transport import services_pb2_grpc # generated from .proto
|
||||
|
||||
"""Fixture to start a LearnerService gRPC server and provide a connected stub."""
|
||||
|
||||
@@ -83,7 +83,7 @@ def close_learner_service_stub(channel, server):
|
||||
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_ready_method(learner_service_stub):
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test the ready method of the UserService."""
|
||||
request = services_pb2.Empty()
|
||||
@@ -94,7 +94,7 @@ def test_ready_method(learner_service_stub):
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_interactions():
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
shutdown_event = Event()
|
||||
|
||||
@@ -138,7 +138,7 @@ def test_send_interactions():
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_transitions():
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test the SendTransitions method with various transition data."""
|
||||
shutdown_event = Event()
|
||||
@@ -184,7 +184,7 @@ def test_send_transitions():
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_send_transitions_empty_stream():
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test SendTransitions with empty stream."""
|
||||
shutdown_event = Event()
|
||||
@@ -214,7 +214,7 @@ def test_send_transitions_empty_stream():
|
||||
def test_stream_parameters():
|
||||
import time
|
||||
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test the StreamParameters method."""
|
||||
shutdown_event = Event()
|
||||
@@ -270,7 +270,7 @@ def test_stream_parameters():
|
||||
@require_package("grpc")
|
||||
@pytest.mark.timeout(3) # force cross-platform watchdog
|
||||
def test_stream_parameters_with_shutdown():
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test StreamParameters handles shutdown gracefully."""
|
||||
shutdown_event = Event()
|
||||
@@ -325,7 +325,7 @@ def test_stream_parameters_waits_and_retries_on_empty_queue():
|
||||
import threading
|
||||
import time
|
||||
|
||||
from lerobot.common.transport import services_pb2
|
||||
from lerobot.transport import services_pb2
|
||||
|
||||
"""Test that StreamParameters waits and retries when the queue is empty."""
|
||||
shutdown_event = Event()
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robots.so100_follower import (
|
||||
from lerobot.robots.so100_follower import (
|
||||
SO100Follower,
|
||||
SO100FollowerConfig,
|
||||
)
|
||||
@@ -50,7 +50,7 @@ def follower():
|
||||
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||
"lerobot.robots.so100_follower.so100_follower.FeetechMotorsBus",
|
||||
side_effect=_bus_side_effect,
|
||||
),
|
||||
patch.object(SO100Follower, "configure", lambda self: None),
|
||||
|
||||
@@ -19,10 +19,10 @@ import gymnasium as gym
|
||||
import pytest
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
from tests.utils import require_env
|
||||
|
||||
|
||||
|
||||
@@ -21,13 +21,13 @@ from pickle import UnpicklingError
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.transition import Transition
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_cuda, require_package
|
||||
|
||||
|
||||
@require_package("grpc")
|
||||
def test_bytes_buffer_size_empty_buffer():
|
||||
from lerobot.common.transport.utils import bytes_buffer_size
|
||||
from lerobot.transport.utils import bytes_buffer_size
|
||||
|
||||
"""Test with an empty buffer."""
|
||||
buffer = io.BytesIO()
|
||||
@@ -38,7 +38,7 @@ def test_bytes_buffer_size_empty_buffer():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_bytes_buffer_size_small_buffer():
|
||||
from lerobot.common.transport.utils import bytes_buffer_size
|
||||
from lerobot.transport.utils import bytes_buffer_size
|
||||
|
||||
"""Test with a small buffer."""
|
||||
buffer = io.BytesIO(b"Hello, World!")
|
||||
@@ -48,7 +48,7 @@ def test_bytes_buffer_size_small_buffer():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_bytes_buffer_size_large_buffer():
|
||||
from lerobot.common.transport.utils import CHUNK_SIZE, bytes_buffer_size
|
||||
from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size
|
||||
|
||||
"""Test with a large buffer."""
|
||||
data = b"x" * (CHUNK_SIZE * 2 + 1000)
|
||||
@@ -59,7 +59,7 @@ def test_bytes_buffer_size_large_buffer():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_send_bytes_in_chunks_empty_data():
|
||||
from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test sending empty data."""
|
||||
message_class = services_pb2.InteractionMessage
|
||||
@@ -69,7 +69,7 @@ def test_send_bytes_in_chunks_empty_data():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_single_chunk_small_data():
|
||||
from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test data that fits in a single chunk."""
|
||||
data = b"Some data"
|
||||
@@ -83,7 +83,7 @@ def test_single_chunk_small_data():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_not_silent_mode():
|
||||
from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import send_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test not silent mode."""
|
||||
data = b"Some data"
|
||||
@@ -95,7 +95,7 @@ def test_not_silent_mode():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_send_bytes_in_chunks_large_data():
|
||||
from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test sending large data."""
|
||||
data = b"x" * (CHUNK_SIZE * 2 + 1000)
|
||||
@@ -112,7 +112,7 @@ def test_send_bytes_in_chunks_large_data():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
|
||||
from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test sending large data with exact chunk size."""
|
||||
data = b"x" * CHUNK_SIZE
|
||||
@@ -125,7 +125,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_empty_data():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
"""Test receiving empty data."""
|
||||
queue = Queue()
|
||||
@@ -139,7 +139,7 @@ def test_receive_bytes_in_chunks_empty_data():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_single_chunk():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving a single chunk message."""
|
||||
queue = Queue()
|
||||
@@ -158,7 +158,7 @@ def test_receive_bytes_in_chunks_single_chunk():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_single_not_end_chunk():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving a single chunk message."""
|
||||
queue = Queue()
|
||||
@@ -176,7 +176,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_multiple_chunks():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving a multi-chunk message."""
|
||||
queue = Queue()
|
||||
@@ -200,7 +200,7 @@ def test_receive_bytes_in_chunks_multiple_chunks():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_multiple_messages():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving multiple complete messages in sequence."""
|
||||
queue = Queue()
|
||||
@@ -236,7 +236,7 @@ def test_receive_bytes_in_chunks_multiple_messages():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_shutdown_during_receive():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test that shutdown event stops receiving mid-stream."""
|
||||
queue = Queue()
|
||||
@@ -260,7 +260,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_only_begin_chunk():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving only a BEGIN chunk without END."""
|
||||
queue = Queue()
|
||||
@@ -280,7 +280,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_missing_begin():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2
|
||||
|
||||
"""Test receiving chunks starting with MIDDLE instead of BEGIN."""
|
||||
queue = Queue()
|
||||
@@ -304,7 +304,7 @@ def test_receive_bytes_in_chunks_missing_begin():
|
||||
# Tests for state_to_bytes and bytes_to_state_dict
|
||||
@require_package("grpc")
|
||||
def test_state_to_bytes_empty_dict():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
"""Test converting empty state dict to bytes."""
|
||||
state_dict = {}
|
||||
@@ -315,7 +315,7 @@ def test_state_to_bytes_empty_dict():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_bytes_to_state_dict_empty_data():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict
|
||||
from lerobot.transport.utils import bytes_to_state_dict
|
||||
|
||||
"""Test converting empty data to state dict."""
|
||||
with pytest.raises(EOFError):
|
||||
@@ -324,7 +324,7 @@ def test_bytes_to_state_dict_empty_data():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_state_to_bytes_simple_dict():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
"""Test converting simple state dict to bytes."""
|
||||
state_dict = {
|
||||
@@ -348,7 +348,7 @@ def test_state_to_bytes_simple_dict():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_state_to_bytes_various_dtypes():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
"""Test converting state dict with various tensor dtypes."""
|
||||
state_dict = {
|
||||
@@ -373,7 +373,7 @@ def test_state_to_bytes_various_dtypes():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_bytes_to_state_dict_invalid_data():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict
|
||||
from lerobot.transport.utils import bytes_to_state_dict
|
||||
|
||||
"""Test bytes_to_state_dict with invalid data."""
|
||||
with pytest.raises(UnpicklingError):
|
||||
@@ -383,7 +383,7 @@ def test_bytes_to_state_dict_invalid_data():
|
||||
@require_cuda
|
||||
@require_package("grpc")
|
||||
def test_state_to_bytes_various_dtypes_cuda():
|
||||
from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
"""Test converting state dict with various tensor dtypes."""
|
||||
state_dict = {
|
||||
@@ -408,7 +408,7 @@ def test_state_to_bytes_various_dtypes_cuda():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_python_object_to_bytes_none():
|
||||
from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
|
||||
"""Test converting None to bytes."""
|
||||
obj = None
|
||||
@@ -440,7 +440,7 @@ def test_python_object_to_bytes_none():
|
||||
)
|
||||
@require_package("grpc")
|
||||
def test_python_object_to_bytes_simple_types(obj):
|
||||
from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
|
||||
"""Test converting simple Python types."""
|
||||
data = python_object_to_bytes(obj)
|
||||
@@ -451,7 +451,7 @@ def test_python_object_to_bytes_simple_types(obj):
|
||||
|
||||
@require_package("grpc")
|
||||
def test_python_object_to_bytes_with_tensors():
|
||||
from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes
|
||||
|
||||
"""Test converting objects containing PyTorch tensors."""
|
||||
obj = {
|
||||
@@ -476,7 +476,7 @@ def test_python_object_to_bytes_with_tensors():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_transitions_to_bytes_empty_list():
|
||||
from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
|
||||
"""Test converting empty transitions list."""
|
||||
transitions = []
|
||||
@@ -488,7 +488,7 @@ def test_transitions_to_bytes_empty_list():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_transitions_to_bytes_single_transition():
|
||||
from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
|
||||
"""Test converting a single transition."""
|
||||
transition = Transition(
|
||||
@@ -528,7 +528,7 @@ def assert_observation_equal(o1: dict, o2: dict):
|
||||
|
||||
@require_package("grpc")
|
||||
def test_transitions_to_bytes_multiple_transitions():
|
||||
from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes
|
||||
|
||||
"""Test converting multiple transitions."""
|
||||
transitions = []
|
||||
@@ -552,7 +552,7 @@ def test_transitions_to_bytes_multiple_transitions():
|
||||
|
||||
@require_package("grpc")
|
||||
def test_receive_bytes_in_chunks_unknown_state():
|
||||
from lerobot.common.transport.utils import receive_bytes_in_chunks
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks
|
||||
|
||||
"""Test receive_bytes_in_chunks with an unknown transfer state."""
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.utils.import_utils import is_package_available
|
||||
from lerobot.utils.import_utils import is_package_available
|
||||
|
||||
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.encoding_utils import (
|
||||
from lerobot.utils.encoding_utils import (
|
||||
decode_sign_magnitude,
|
||||
decode_twos_complement,
|
||||
encode_sign_magnitude,
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.io_utils import deserialize_json_into_object
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -22,7 +22,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
@@ -18,7 +18,7 @@ import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
from lerobot.common.utils.queue import get_last_item_from_queue
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
|
||||
|
||||
def test_get_last_item_single_item():
|
||||
|
||||
@@ -17,7 +17,7 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.random_utils import (
|
||||
from lerobot.utils.random_utils import (
|
||||
deserialize_numpy_rng_state,
|
||||
deserialize_python_rng_state,
|
||||
deserialize_rng_state,
|
||||
|
||||
@@ -20,8 +20,8 @@ from typing import Callable
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from lerobot.common.constants import (
|
||||
from lerobot.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
@@ -24,7 +24,7 @@ from lerobot.common.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.common.utils.train_utils import (
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
@@ -69,7 +69,7 @@ def test_update_last_checkpoint(tmp_path):
|
||||
assert last_checkpoint.resolve() == checkpoint
|
||||
|
||||
|
||||
@patch("lerobot.common.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
cfg = Mock()
|
||||
|
||||
Reference in New Issue
Block a user