Compare commits
15 Commits
user/pepij
...
torchcodec
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7a36bb661 | ||
|
|
9466da0ddf | ||
|
|
bb7542d799 | ||
|
|
1a1740d90d | ||
|
|
0b379e9c4e | ||
|
|
2295e6c45b | ||
|
|
920079230e | ||
|
|
744906098a | ||
|
|
e1732b4954 | ||
|
|
c03b0db8aa | ||
|
|
a8fcd3512d | ||
|
|
a963dba256 | ||
|
|
2f9cbfbc4f | ||
|
|
e8126dc3d6 | ||
|
|
4e2dc91e59 |
@@ -73,7 +73,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/artifacts
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
2
.github/workflows/test-docker-build.yml
vendored
2
.github/workflows/test-docker-build.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
|
||||
uses: tj-actions/changed-files@v44
|
||||
with:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -78,7 +78,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/artifacts
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
exclude: ^(tests/data)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
|
||||
@@ -291,7 +291,7 @@ sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
|
||||
### Decoding parameters
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
- `pyav`
|
||||
- `pyav` (default)
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
|
||||
@@ -583,13 +583,6 @@ Let's explain it:
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
|
||||
@@ -69,7 +69,6 @@ from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
@@ -463,7 +462,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
"""
|
||||
super().__init__()
|
||||
@@ -474,7 +473,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.video_backend = video_backend if video_backend else "torchcodec"
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
@@ -1028,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.video_backend = video_backend if video_backend is not None else "torchcodec"
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
@@ -28,23 +27,14 @@ import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
backend: str = "torchcodec",
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
@@ -53,15 +43,13 @@ def decode_video_frames(
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
@@ -185,14 +173,9 @@ def decode_video_frames_torchcodec(
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
video_path = str(video_path)
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
decoder = VideoDecoder(video_path, device=device)
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
# get metadata for frame information
|
||||
|
||||
@@ -119,7 +119,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -147,8 +149,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -410,10 +413,11 @@ class ACT(nn.Module):
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
if "observation.images" in batch:
|
||||
batch_size = batch["observation.images"][0].shape[0]
|
||||
else:
|
||||
batch_size = batch["observation.environment_state"].shape[0]
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
if "observation.images" in batch
|
||||
else batch["observation.environment_state"]
|
||||
).shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
@@ -486,21 +490,20 @@ class ACT(nn.Module):
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
||||
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
|
||||
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
||||
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
|
||||
@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||
connected to the computer.
|
||||
"""
|
||||
if mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -100,7 +100,7 @@ def save_images_from_cameras(
|
||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -253,7 +253,7 @@ class IntelRealSenseCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -287,7 +287,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -375,7 +375,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -269,7 +269,7 @@ class OpenCVCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -286,7 +286,7 @@ class OpenCVCamera:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -398,7 +398,7 @@ class OpenCVCamera:
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
|
||||
@@ -332,7 +332,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -356,7 +356,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -646,7 +646,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -691,7 +691,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -757,7 +757,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -793,7 +793,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
|
||||
@@ -313,7 +313,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -337,7 +337,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -664,7 +664,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -702,7 +702,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -782,7 +782,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -818,7 +818,7 @@ class FeetechMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
|
||||
@@ -69,13 +69,7 @@ class WandBLogger:
|
||||
os.environ["WANDB_SILENT"] = "True"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = (
|
||||
cfg.wandb.run_id
|
||||
if cfg.wandb.run_id
|
||||
else get_wandb_run_id_from_filesystem(self.log_dir)
|
||||
if cfg.resume
|
||||
else None
|
||||
)
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=self.cfg.project,
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.common import (
|
||||
policies, # noqa: F401
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
||||
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +35,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = "pyav"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -47,7 +46,6 @@ class WandBConfig:
|
||||
project: str = "lerobot"
|
||||
entity: str | None = None
|
||||
notes: str | None = None
|
||||
run_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -79,9 +79,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
raise ValueError("A config_path is expected when resuming a run.")
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
|
||||
@@ -69,7 +69,7 @@ dependencies = [
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
|
||||
"torchcodec>=0.2.1",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
@@ -103,7 +103,30 @@ requires-poetry = ">=2.1"
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||
exclude = [
|
||||
"tests/data",
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
38
tests/lerobot/common/datasets/test_utils.py
Normal file
38
tests/lerobot/common/datasets/test_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
@@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
|
||||
doesnt need to be merged into the `main` branch. Then you need to run this script and update the tests artifacts.
|
||||
|
||||
Example usage:
|
||||
`python tests/artifacts/datasets/save_dataset_to_safetensors.py`
|
||||
`python tests/scripts/save_dataset_to_safetensors.py`
|
||||
"""
|
||||
|
||||
import shutil
|
||||
@@ -88,4 +88,4 @@ if __name__ == "__main__":
|
||||
"lerobot/nyu_franka_play_dataset",
|
||||
"lerobot/cmu_stretch",
|
||||
]:
|
||||
save_dataset_to_safetensors("tests/artifacts/datasets", repo_id=dataset)
|
||||
save_dataset_to_safetensors("tests/data/save_dataset_to_safetensors", repo_id=dataset)
|
||||
@@ -27,7 +27,7 @@ from lerobot.common.datasets.transforms import (
|
||||
)
|
||||
from lerobot.common.utils.random_utils import seeded_context
|
||||
|
||||
ARTIFACT_DIR = Path("tests/artifacts/image_transforms")
|
||||
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
@@ -141,5 +141,5 @@ if __name__ == "__main__":
|
||||
raise RuntimeError("No policies were provided!")
|
||||
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
output_dir = Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy}_{file_name_extra}"
|
||||
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)
|
||||
@@ -146,7 +146,7 @@ def test_camera(request, camera_type, mock):
|
||||
camera.connect()
|
||||
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.common.robot_devices.control_configs import (
|
||||
)
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
|
||||
from tests.robots.test_robots import make_robot
|
||||
from tests.test_robots import make_robot
|
||||
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
|
||||
|
||||
|
||||
@@ -473,12 +473,12 @@ def test_flatten_unflatten_dict():
|
||||
)
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility(repo_id):
|
||||
"""The artifacts for this test have been generated by `tests/artifacts/datasets/save_dataset_to_safetensors.py`."""
|
||||
"""The artifacts for this test have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
|
||||
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0])
|
||||
|
||||
test_dir = Path("tests/artifacts/datasets") / repo_id
|
||||
test_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
|
||||
|
||||
def load_and_compare(i):
|
||||
new_frame = dataset[i] # noqa: B023
|
||||
@@ -23,7 +23,8 @@ from gymnasium.utils.env_checker import check_env
|
||||
import lerobot
|
||||
from lerobot.common.envs.factory import make_env, make_env_config
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from tests.utils import require_env
|
||||
|
||||
from .utils import require_env
|
||||
|
||||
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
|
||||
|
||||
@@ -33,7 +33,7 @@ from lerobot.scripts.visualize_image_transforms import (
|
||||
save_all_transforms,
|
||||
save_each_transform,
|
||||
)
|
||||
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.scripts.save_image_transforms_to_safetensors import ARTIFACT_DIR
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.common.utils.random_utils import seeded_context
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -407,10 +407,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
|
||||
should be updated.
|
||||
4. Check that this test now passes.
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
"""
|
||||
ds_name = ds_repo_id.split("/")[-1]
|
||||
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
artifact_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{ds_name}_{policy_name}_{file_name_extra}"
|
||||
)
|
||||
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,32 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
Reference in New Issue
Block a user