Compare commits
63 Commits
main
...
user/rcade
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eeb7e0d490 | ||
|
|
fac22b1dfe | ||
|
|
2a0c497c87 | ||
|
|
34735ec486 | ||
|
|
57af6b90a2 | ||
|
|
44dc6ab182 | ||
|
|
569c9092dc | ||
|
|
181f20a16d | ||
|
|
bf2fc099cc | ||
|
|
5a16163790 | ||
|
|
875c1fbb2a | ||
|
|
b0f4f4f9a1 | ||
|
|
1664dac300 | ||
|
|
895182b272 | ||
|
|
2c7089e9f9 | ||
|
|
3e2c43296c | ||
|
|
6fc1d0dfc1 | ||
|
|
694a5dbca8 | ||
|
|
61d9d74308 | ||
|
|
87e39da641 | ||
|
|
03d778e672 | ||
|
|
14bc60270f | ||
|
|
0407b0dab2 | ||
|
|
8b4a1c7a7f | ||
|
|
d878ec27d5 | ||
|
|
42325f5990 | ||
|
|
d2a17d2c74 | ||
|
|
3fc351074c | ||
|
|
3e80ec65d2 | ||
|
|
8be46f3760 | ||
|
|
30c7a9e6fc | ||
|
|
d525e1b0f8 | ||
|
|
7a659dbd6b | ||
|
|
1993d29296 | ||
|
|
e615176845 | ||
|
|
3918e118c3 | ||
|
|
17bc32c1e3 | ||
|
|
73692e7459 | ||
|
|
e6aa8cc641 | ||
|
|
5494faf096 | ||
|
|
01f8cede0b | ||
|
|
5d092b1be2 | ||
|
|
7560372404 | ||
|
|
1bc13f39bf | ||
|
|
95de4b7454 | ||
|
|
c2388c59be | ||
|
|
f7a7ed9aa9 | ||
|
|
7411cf2a7a | ||
|
|
2bebdf78a0 | ||
|
|
fb06417a83 | ||
|
|
dd7bce7563 | ||
|
|
68a561570c | ||
|
|
52e760a88e | ||
|
|
798373e7bf | ||
|
|
a0432f1608 | ||
|
|
3ba05d53b9 | ||
|
|
72b2e342ae | ||
|
|
d83d34d9b3 | ||
|
|
3ff789c181 | ||
|
|
6e77a399a2 | ||
|
|
8a7aa50e97 | ||
|
|
47aac0dff7 | ||
|
|
858d49fc04 |
@@ -182,7 +182,7 @@ class OpenCVCamera:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, camera_index: int, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = OpenCVCameraConfig()
|
||||
# Overwrite config arguments using kwargs
|
||||
@@ -194,11 +194,6 @@ class OpenCVCamera:
|
||||
self.height = config.height
|
||||
self.color_mode = config.color_mode
|
||||
|
||||
if not isinstance(self.camera_index, int):
|
||||
raise ValueError(
|
||||
f"Camera index must be provided as an int, but {self.camera_index} was given instead."
|
||||
)
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
self.thread = None
|
||||
|
||||
@@ -2,6 +2,7 @@ from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -39,6 +40,16 @@ def save_depth_image(depth, path, write_shape=False):
|
||||
cv2.imwrite(str(path), depth_image)
|
||||
|
||||
|
||||
def convert_torch_image_to_cv2(tensor, rgb_to_bgr=True):
|
||||
assert tensor.ndim == 3
|
||||
c, h, w = tensor.shape
|
||||
assert c < h and c < w
|
||||
color_image = einops.rearrange(tensor, "c h w -> h w c").numpy()
|
||||
if rgb_to_bgr:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
return color_image
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
|
||||
@@ -5,6 +5,7 @@ from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
from dynamixel_sdk import (
|
||||
COMM_SUCCESS,
|
||||
DXL_HIBYTE,
|
||||
@@ -21,9 +22,11 @@ from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError,
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUD_RATE = 1_000_000
|
||||
BAUDRATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
|
||||
@@ -86,6 +89,16 @@ X_SERIES_CONTROL_TABLE = {
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
X_SERIES_BAUDRATE_TABLE = {
|
||||
0: 9_600,
|
||||
1: 57_600,
|
||||
2: 115_200,
|
||||
3: 1_000_000,
|
||||
4: 2_000_000,
|
||||
5: 3_000_000,
|
||||
6: 4_000_000,
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
@@ -98,7 +111,73 @@ MODEL_CONTROL_TABLE = {
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"x_series": 4096,
|
||||
"xl330-m077": 4096,
|
||||
"xl330-m288": 4096,
|
||||
"xl430-w250": 4096,
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"x_series": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_indices_to_baudrates(values: np.ndarray | list[int], models: list[str]):
|
||||
assert len(values) == len(models)
|
||||
for i in range(len(values)):
|
||||
model = models[i]
|
||||
index = values[i]
|
||||
values[i] = MODEL_BAUDRATE_TABLE[model][index]
|
||||
return values
|
||||
|
||||
|
||||
def convert_baudrates_to_indices(values: np.ndarray | list[int], models: list[str]):
|
||||
assert len(values) == len(models)
|
||||
for i in range(len(values)):
|
||||
model = models[i]
|
||||
brate = values[i]
|
||||
table_values = list(MODEL_BAUDRATE_TABLE[model].values())
|
||||
table_keys = list(MODEL_BAUDRATE_TABLE[model].keys())
|
||||
values[i] = table_keys[table_values.index(brate)]
|
||||
return values
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes):
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
@@ -233,6 +312,7 @@ class DynamixelMotorsBus:
|
||||
port: str,
|
||||
motors: dict[str, tuple[int, str]],
|
||||
extra_model_control_table: dict[str, list[tuple]] | None = None,
|
||||
extra_model_resolution: dict[str, int] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
@@ -241,6 +321,10 @@ class DynamixelMotorsBus:
|
||||
if extra_model_control_table:
|
||||
self.model_ctrl_table.update(extra_model_control_table)
|
||||
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
if extra_model_resolution:
|
||||
self.model_resolution.update(extra_model_resolution)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
@@ -268,52 +352,279 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
raise
|
||||
|
||||
self.port_handler.setBaudRate(BAUD_RATE)
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
# Set expected baud rate for the bus
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
|
||||
if not self.are_motors_configured():
|
||||
print(
|
||||
r"/!\ First, verify that all the cables are connected the proper way. If you detect an issue, before making any modification, unplug the power cord to not damage the motors. Rewire correctly. Then plug the power again and relaunch the script."
|
||||
)
|
||||
print(
|
||||
r"/!\ Secondly, if the cables connection look correct and it is the first time that you use these motors, follow these manual steps to configure them."
|
||||
)
|
||||
input("Press Enter to configure your motors...")
|
||||
print()
|
||||
self.configure_motors()
|
||||
|
||||
def reconnect(self):
|
||||
self.port_handler = PortHandler(self.port)
|
||||
self.packet_handler = PacketHandler(PROTOCOL_VERSION)
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def configure_motors(self):
|
||||
# TODO(rcadene): This script assumes motors follow the X_SERIES baudrates
|
||||
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(X_SERIES_BAUDRATE_TABLE.values())
|
||||
ids_per_baudrate = {}
|
||||
for baudrate in all_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices()
|
||||
if len(present_ids) > 0:
|
||||
ids_per_baudrate[baudrate] = present_ids
|
||||
print(f"Motor indices detected: {ids_per_baudrate}")
|
||||
print()
|
||||
|
||||
possible_baudrates = list(ids_per_baudrate.keys())
|
||||
possible_ids = list({idx for sublist in ids_per_baudrate.values() for idx in sublist})
|
||||
untaken_ids = list(set(range(MAX_ID_RANGE)) - set(possible_ids) - set(self.motor_indices))
|
||||
|
||||
# Connect successively one motor to the chain and write a unique random index for each
|
||||
for i in range(len(self.motors)):
|
||||
self.disconnect()
|
||||
print("1. Unplug the power cord")
|
||||
print(
|
||||
f"2. Plug/unplug minimal number of cables to only have the first {i+1} motor(s) ({self.motor_names[:i+1]}) connected."
|
||||
)
|
||||
print("3. Re-plug the power cord.")
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
if i > 0:
|
||||
try:
|
||||
self._read_with_motor_ids(self.motor_models, untaken_ids[:i], "ID")
|
||||
except ConnectionError:
|
||||
print(f"Failed to read from {untaken_ids[:i+1]}. Make sure the power cord is plugged in.")
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
self.reconnect()
|
||||
|
||||
print("Scanning possible baudrates and motor indices")
|
||||
motor_found = False
|
||||
for baudrate in possible_baudrates:
|
||||
self.set_bus_baudrate(baudrate)
|
||||
present_ids = self.find_motor_indices(possible_ids)
|
||||
if len(present_ids) == 1:
|
||||
present_idx = present_ids[0]
|
||||
print(f"Detected motor with index {present_idx}")
|
||||
|
||||
if baudrate != BAUDRATE:
|
||||
print(f"Setting its baudrate to {BAUDRATE}")
|
||||
baudrate_idx = list(X_SERIES_BAUDRATE_TABLE.values()).index(BAUDRATE)
|
||||
|
||||
# The write can fail, so we allow retries
|
||||
for _ in range(NUM_WRITE_RETRY):
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate", baudrate_idx
|
||||
)
|
||||
time.sleep(0.5)
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
try:
|
||||
present_baudrate_idx = self._read_with_motor_ids(
|
||||
self.motor_models, present_idx, "Baud_Rate"
|
||||
)
|
||||
except ConnectionError:
|
||||
print("Failed to write baudrate. Retrying.")
|
||||
self.set_bus_baudrate(baudrate)
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise
|
||||
|
||||
if present_baudrate_idx != baudrate_idx:
|
||||
raise OSError("Failed to write baudrate.")
|
||||
|
||||
print(f"Setting its index to a temporary untaken index ({untaken_ids[i]})")
|
||||
self._write_with_motor_ids(self.motor_models, present_idx, "ID", untaken_ids[i])
|
||||
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, untaken_ids[i], "ID")
|
||||
if present_idx != untaken_ids[i]:
|
||||
raise OSError("Failed to write index.")
|
||||
|
||||
motor_found = True
|
||||
break
|
||||
elif len(present_ids) > 1:
|
||||
raise OSError(f"More than one motor detected ({present_ids}), but only one was expected.")
|
||||
|
||||
if not motor_found:
|
||||
raise OSError(
|
||||
"No motor found, but one new motor expected. Verify power cord is plugged in and retry."
|
||||
)
|
||||
print()
|
||||
|
||||
print(f"Setting expected motor indices: {self.motor_indices}")
|
||||
self.set_bus_baudrate(BAUDRATE)
|
||||
self._write_with_motor_ids(
|
||||
self.motor_models, untaken_ids[: len(self.motors)], "ID", self.motor_indices
|
||||
)
|
||||
print()
|
||||
|
||||
if (self.read("ID") != self.motor_indices).any():
|
||||
raise OSError("Failed to write motors indices.")
|
||||
|
||||
print("Configuration is done!")
|
||||
|
||||
def find_motor_indices(self, possible_ids=None):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self._read_with_motor_ids(self.motor_models, [idx], "ID")[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[int]:
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
def set_calibration(self, calibration: dict[str, tuple[int, bool]]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if not self.calibration:
|
||||
return values
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 centered degree range [-180.0, 180.0[
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 42638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered degree range [-180, 180[. This function first applies the pre-computed calibration to convert
|
||||
from [0, 2**32[ to [-2048, 2048[, then divide by 2048.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32[ to centered signed int32 range [-2**31, 2**31[
|
||||
values = values.astype(np.int32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
|
||||
if values[i] is not None:
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
values[i] += homing_offset
|
||||
# Update direction of rotation of the motor to match between leader and follower. In fact, the motor of the leader for a given joint
|
||||
# can be assembled in an opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to centered resolution range [-resolution, resolution[ (e.g. [-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range [-resolution, resolution[ to the universal float32 centered degree range [-180, 180[
|
||||
values = values.astype(np.float32)
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
values[i] = values[i] / (resolution // 2) * 180
|
||||
|
||||
return values
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
if not self.calibration:
|
||||
return values
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from the universal float32 centered degree range [-180, 180[ to centered resolution range [-resolution, resolution[
|
||||
for i, name in enumerate(motor_names):
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
values[i] = values[i] / 180 * (resolution // 2)
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
|
||||
# Convert from range [-resolution, resolution[ to centered signed int32 range [-2**31, 2**31[
|
||||
for i, name in enumerate(motor_names):
|
||||
homing_offset, drive_mode = self.calibration[name]
|
||||
values[i] -= homing_offset
|
||||
|
||||
if values[i] is not None:
|
||||
values[i] -= homing_offset
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
# Update direction of rotation of the motor that was matching between leader and follower to their original direction.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an opposite direction in term of rotation
|
||||
# than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
return values
|
||||
|
||||
def _read_with_motor_ids(self, motor_models, motor_ids, data_name):
|
||||
return_list = True
|
||||
if not isinstance(motor_ids, list):
|
||||
return_list = False
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
comm = group.txRxPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
return values
|
||||
else:
|
||||
return values[0]
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
@@ -367,7 +678,7 @@ class DynamixelMotorsBus:
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
@@ -380,6 +691,26 @@ class DynamixelMotorsBus:
|
||||
|
||||
return values
|
||||
|
||||
def _write_with_motor_ids(self, motor_models, motor_ids, data_name, values):
|
||||
if not isinstance(motor_ids, list):
|
||||
motor_ids = [motor_ids]
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes)
|
||||
group.addParam(idx, data)
|
||||
|
||||
comm = group.txPacket()
|
||||
if comm != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
@@ -406,7 +737,7 @@ class DynamixelMotorsBus:
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.revert_calibration(values, motor_names)
|
||||
|
||||
values = values.tolist()
|
||||
@@ -422,30 +753,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
DXL_LOBYTE(DXL_LOWORD(value)),
|
||||
DXL_HIBYTE(DXL_LOWORD(value)),
|
||||
DXL_LOBYTE(DXL_HIWORD(value)),
|
||||
DXL_HIBYTE(DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
|
||||
data = convert_to_bytes(value, bytes)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
|
||||
@@ -1,46 +1,7 @@
|
||||
def make_robot(name):
|
||||
if name == "koch":
|
||||
# TODO(rcadene): Add configurable robot from command line and yaml config
|
||||
# TODO(rcadene): Add example with and without cameras
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from lerobot.common.robot_devices.robots.koch import KochRobot
|
||||
import hydra
|
||||
from omegaconf import DictConfig
|
||||
|
||||
robot = KochRobot(
|
||||
leader_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl330-m077"),
|
||||
"shoulder_lift": (2, "xl330-m077"),
|
||||
"elbow_flex": (3, "xl330-m077"),
|
||||
"wrist_flex": (4, "xl330-m077"),
|
||||
"wrist_roll": (5, "xl330-m077"),
|
||||
"gripper": (6, "xl330-m077"),
|
||||
},
|
||||
),
|
||||
},
|
||||
follower_arms={
|
||||
"main": DynamixelMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
},
|
||||
),
|
||||
},
|
||||
cameras={
|
||||
"laptop": OpenCVCamera(0, fps=30, width=640, height=480),
|
||||
"phone": OpenCVCamera(1, fps=30, width=640, height=480),
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Robot '{name}' not found.")
|
||||
|
||||
def make_robot(cfg: DictConfig):
|
||||
robot = hydra.utils.instantiate(cfg)
|
||||
return robot
|
||||
|
||||
@@ -29,9 +29,12 @@ URL_90_DEGREE_POSITION = {
|
||||
# Calibration logic
|
||||
########################################################################
|
||||
|
||||
# In range ]-2048, 2048[
|
||||
TARGET_HORIZONTAL_POSITION = np.array([0, -1024, 1024, 0, -1024, 0])
|
||||
TARGET_90_DEGREE_POSITION = np.array([1024, 0, 0, 1024, 0, -1024])
|
||||
GRIPPER_OPEN = np.array([-400])
|
||||
|
||||
# In range ]-180, 180[
|
||||
GRIPPER_OPEN = np.array([-35.156])
|
||||
|
||||
|
||||
def apply_homing_offset(values: np.array, homing_offset: np.array) -> np.array:
|
||||
@@ -330,7 +333,9 @@ class KochRobot:
|
||||
|
||||
# Connect the arms
|
||||
for name in self.follower_arms:
|
||||
print(f"Connecting {name} follower arm.")
|
||||
self.follower_arms[name].connect()
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
# Reset the arms and load or run calibration
|
||||
@@ -357,6 +362,18 @@ class KochRobot:
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].set_calibration(calibration[f"leader_{name}"])
|
||||
|
||||
for name in self.leader_arms:
|
||||
values = self.leader_arms[name].read("Present_Position")
|
||||
if (values < -180).any() or (values >= 180).any():
|
||||
raise ValueError(
|
||||
f"At least one of the motor of the {name} leader arm has a joint value outside of its centered degree range of ]-180, 180[."
|
||||
'This "jump of range" can be caused by a hardware issue, or you might have unexpectedly completed a full rotation of the motor '
|
||||
"during manipulation or transportation of your robot. "
|
||||
f"The values and motors: {values} {self.leader_arms[name].motor_names}.\n"
|
||||
"Rotate the arm to fit the range ]-180, 180[ and relaunch the script, or recalibrate all motors by setting a different "
|
||||
"`calibration_path` during the instatiation of your robot."
|
||||
)
|
||||
|
||||
# Set better PID values to close the gap between recored states and actions
|
||||
# TODO(rcadene): Implement an automatic procedure to set optimial PID values for each motor
|
||||
for name in self.follower_arms:
|
||||
@@ -500,11 +517,7 @@ class KochRobot:
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
for name in self.cameras:
|
||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||
img = torch.from_numpy(images[name])
|
||||
img = img.type(torch.float32) / 255
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
obs_dict[f"observation.images.{name}"] = img
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
|
||||
@@ -107,13 +107,13 @@ training:
|
||||
min_max: [0.8, 1.2]
|
||||
saturation:
|
||||
weight: 1
|
||||
min_max: [0.5, 1.5]
|
||||
min_max: [0.7, 1.3]
|
||||
hue:
|
||||
weight: 1
|
||||
min_max: [-0.05, 0.05]
|
||||
min_max: [-0.02, 0.02]
|
||||
sharpness:
|
||||
weight: 1
|
||||
min_max: [0.8, 1.2]
|
||||
min_max: [0.1, 1.9]
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
|
||||
102
lerobot/configs/policy/act_koch_real_bigger.yaml
Normal file
102
lerobot/configs/policy/act_koch_real_bigger.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/koch_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 160000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 5000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 16
|
||||
lr: 1e-4
|
||||
lr_backbone: 1e-4
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} + 1 / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 4
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
38
lerobot/configs/robot/koch.yaml
Normal file
38
lerobot/configs/robot/koch.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
_target_: lerobot.common.robot_devices.robots.koch.KochRobot
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
shoulder_lift: [2, "xl330-m077"]
|
||||
elbow_flex: [3, "xl330-m077"]
|
||||
wrist_flex: [4, "xl330-m077"]
|
||||
wrist_roll: [5, "xl330-m077"]
|
||||
gripper: [6, "xl330-m077"]
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
shoulder_lift: [2, "xl430-w250"]
|
||||
elbow_flex: [3, "xl330-m288"]
|
||||
wrist_flex: [4, "xl330-m288"]
|
||||
wrist_roll: [5, "xl330-m288"]
|
||||
gripper: [6, "xl330-m288"]
|
||||
cameras:
|
||||
laptop:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 0
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
phone:
|
||||
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||
camera_index: 1
|
||||
fps: 30
|
||||
width: 640
|
||||
height: 480
|
||||
@@ -4,6 +4,9 @@ Examples of usage:
|
||||
- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate
|
||||
|
||||
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
|
||||
python lerobot/scripts/control_robot.py teleoperate '~cameras'
|
||||
```
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency:
|
||||
@@ -14,7 +17,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
@@ -32,7 +35,7 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
|
||||
- Replay this test episode:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay_episode \
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
@@ -42,12 +45,11 @@ python lerobot/scripts/control_robot.py replay_episode \
|
||||
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record_dataset \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--run-compute-stats 1 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
@@ -74,7 +76,14 @@ DATA_DIR=data python lerobot/scripts/train.py \
|
||||
|
||||
- Run the pretrained policy on the robot:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py run_policy \
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
--episode-time-s 30 \
|
||||
--reset-time-s 10
|
||||
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
|
||||
```
|
||||
"""
|
||||
@@ -88,8 +97,10 @@ import platform
|
||||
import shutil
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from huggingface_hub import create_branch
|
||||
@@ -179,7 +190,8 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||
logging.info(info_str)
|
||||
|
||||
|
||||
def get_is_headless():
|
||||
@cache
|
||||
def is_headless():
|
||||
if platform.system() == "Linux":
|
||||
display = os.environ.get("DISPLAY")
|
||||
if display is None or display == "":
|
||||
@@ -213,8 +225,10 @@ def teleoperate(robot: Robot, fps: int | None = None, teleop_time_s: float | Non
|
||||
break
|
||||
|
||||
|
||||
def record_dataset(
|
||||
def record(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module | None = None,
|
||||
hydra_cfg: DictConfig | None = None,
|
||||
fps: int | None = None,
|
||||
root="data",
|
||||
repo_id="lerobot/debug",
|
||||
@@ -255,32 +269,10 @@ def record_dataset(
|
||||
else:
|
||||
episode_index = 0
|
||||
|
||||
is_headless = get_is_headless()
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
os.system('say "Warmup" &')
|
||||
is_warmup_print = True
|
||||
|
||||
now = time.perf_counter()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
if not is_headless:
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
if is_headless():
|
||||
logging.info(
|
||||
"Headless environment detected. Display cameras on screen and keyboard inputs will not be available."
|
||||
)
|
||||
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
@@ -290,9 +282,7 @@ def record_dataset(
|
||||
stop_recording = False
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
if is_headless:
|
||||
logging.info("Headless environment detected. Keyboard input will not be available.")
|
||||
else:
|
||||
if not is_headless():
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
@@ -315,6 +305,53 @@ def record_dataset(
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
# Load policy if any
|
||||
if policy is not None:
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
# override fps using policy fps
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
# Execute a few seconds without recording data, to give times
|
||||
# to the robot devices to connect and start synchronizing.
|
||||
timestamp = 0
|
||||
start_time = time.perf_counter()
|
||||
is_warmup_print = False
|
||||
while timestamp < warmup_time_s:
|
||||
if not is_warmup_print:
|
||||
logging.info("Warming up (no data recording)")
|
||||
os.system('say "Warmup" &')
|
||||
is_warmup_print = True
|
||||
|
||||
now = time.perf_counter()
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_time
|
||||
|
||||
# Save images using threads to reach high fps (30 and more)
|
||||
# Using `with` to exist smoothly if an execption is raised.
|
||||
# Using only 4 worker threads to avoid blocking the main thread.
|
||||
@@ -330,7 +367,11 @@ def record_dataset(
|
||||
start_time = time.perf_counter()
|
||||
while timestamp < episode_time_s:
|
||||
now = time.perf_counter()
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
|
||||
if policy is None:
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
else:
|
||||
observation = robot.capture_observation()
|
||||
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
not_image_keys = [key for key in observation if "image" not in key]
|
||||
@@ -342,11 +383,44 @@ def record_dataset(
|
||||
)
|
||||
]
|
||||
|
||||
if not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
for key in not_image_keys:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
if policy is not None:
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if "image" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
|
||||
if device.type == "mps":
|
||||
for name in observation:
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
action = action.to("cpu")
|
||||
|
||||
robot.send_action(action)
|
||||
action = {"action": action}
|
||||
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
@@ -434,7 +508,7 @@ def record_dataset(
|
||||
if is_last_episode:
|
||||
logging.info("Done recording")
|
||||
os.system('say "Done recording"')
|
||||
if not is_headless:
|
||||
if not is_headless():
|
||||
listener.stop()
|
||||
|
||||
logging.info("Waiting for threads writing the images on disk to terminate...")
|
||||
@@ -444,6 +518,10 @@ def record_dataset(
|
||||
pass
|
||||
break
|
||||
|
||||
robot.disconnect()
|
||||
if not is_headless():
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
num_episodes = episode_index
|
||||
|
||||
logging.info("Encoding videos")
|
||||
@@ -495,6 +573,7 @@ def record_dataset(
|
||||
stats = compute_stats(lerobot_dataset)
|
||||
lerobot_dataset.stats = stats
|
||||
else:
|
||||
stats = {}
|
||||
logging.info("Skipping computation of the dataset statistrics")
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
@@ -516,7 +595,7 @@ def record_dataset(
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
def replay(robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug"):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
@@ -546,61 +625,6 @@ def replay_episode(robot: Robot, episode: int, fps: int | None = None, root="dat
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
|
||||
def run_policy(robot: Robot, policy: torch.nn.Module, hydra_cfg: DictConfig, run_time_s: float | None = None):
|
||||
# TODO(rcadene): Add option to record eval dataset and logs
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(hydra_cfg.seed)
|
||||
|
||||
fps = hydra_cfg.env.fps
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
start_time = time.perf_counter()
|
||||
while True:
|
||||
now = time.perf_counter()
|
||||
|
||||
observation = robot.capture_observation()
|
||||
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type)
|
||||
if device.type == "cuda" and hydra_cfg.use_amp
|
||||
else nullcontext(),
|
||||
):
|
||||
# add batch dimension to 1
|
||||
for name in observation:
|
||||
observation[name] = observation[name].unsqueeze(0)
|
||||
|
||||
if device.type == "mps":
|
||||
for name in observation:
|
||||
observation[name] = observation[name].to(device)
|
||||
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
robot.send_action(action.to("cpu"))
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - now
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
if run_time_s is not None and time.perf_counter() - start_time > run_time_s:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="mode", required=True)
|
||||
@@ -608,10 +632,15 @@ if __name__ == "__main__":
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument(
|
||||
"--robot",
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="koch",
|
||||
help="Name of the robot provided to the `make_robot(name)` factory function.",
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
base_parser.add_argument(
|
||||
"robot_overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
@@ -619,7 +648,7 @@ if __name__ == "__main__":
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record_dataset", parents=[base_parser])
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
@@ -678,8 +707,22 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
@@ -697,41 +740,38 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
|
||||
parser_policy.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
type=str,
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`."
|
||||
),
|
||||
)
|
||||
parser_policy.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
|
||||
control_mode = args.mode
|
||||
robot_name = args.robot
|
||||
robot_path = args.robot_path
|
||||
robot_overrides = args.robot_overrides
|
||||
kwargs = vars(args)
|
||||
del kwargs["mode"]
|
||||
del kwargs["robot"]
|
||||
del kwargs["robot_path"]
|
||||
del kwargs["robot_overrides"]
|
||||
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
robot = make_robot(robot_name)
|
||||
if control_mode == "teleoperate":
|
||||
teleoperate(robot, **kwargs)
|
||||
elif control_mode == "record_dataset":
|
||||
record_dataset(robot, **kwargs)
|
||||
elif control_mode == "replay_episode":
|
||||
replay_episode(robot, **kwargs)
|
||||
|
||||
elif control_mode == "run_policy":
|
||||
pretrained_policy_path = get_pretrained_policy_path(args.pretrained_policy_name_or_path)
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", args.overrides)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
run_policy(robot, policy, hydra_cfg)
|
||||
elif control_mode == "record":
|
||||
pretrained_policy_name_or_path = args.pretrained_policy_name_or_path
|
||||
overrides = args.overrides
|
||||
del kwargs["pretrained_policy_name_or_path"]
|
||||
del kwargs["overrides"]
|
||||
|
||||
policy_cfg = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
policy_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides)
|
||||
policy = make_policy(hydra_cfg=policy_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
record(robot, policy, policy_cfg, **kwargs)
|
||||
else:
|
||||
record(robot, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(robot, **kwargs)
|
||||
|
||||
203
poetry.lock
generated
203
poetry.lock
generated
@@ -795,6 +795,7 @@ files = [
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:01f811d0c6722f74743c153a7be0144686daeafa968c473e60f6b6c5dc8f5bff"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a36e97d31eeb66e6d5913130695d188ceee1248029961012a8b4f59fd3f58670"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25d620123a733661dc740ef2b456601ddbaa69ae2b50d8141daa3c684bda385c"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9fdc4e73578bebb1c8d0f8bea2243a5a9e179f08c74d98576123b59b75e5cac"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e65830634c58158557f0ab90e5d1f492bcbc6b74587b05825ba4c20b634dc1bd"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c01f9ab8f93295341aeab2d606d484d9cff9d05f57581e2180433ec8e0d38307"},
|
||||
{file = "dora_rs-0.3.5-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:5d6d46a49a34cd7e4f74496a1089b9a1b78282c219a28d98fe031a763e92d530"},
|
||||
@@ -1536,20 +1537,6 @@ files = [
|
||||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "intel-openmp"
|
||||
version = "2021.4.0"
|
||||
description = "Intel OpenMP* Runtime Library"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
|
||||
{file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.4"
|
||||
@@ -1741,9 +1728,13 @@ files = [
|
||||
{file = "lxml-5.2.2-cp36-cp36m-win_amd64.whl", hash = "sha256:edcfa83e03370032a489430215c1e7783128808fd3e2e0a3225deee278585196"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:28bf95177400066596cdbcfc933312493799382879da504633d16cf60bba735b"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3a745cc98d504d5bd2c19b10c79c61c7c3df9222629f1b6210c0368177589fb8"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b590b39ef90c6b22ec0be925b211298e810b4856909c8ca60d27ffbca6c12e6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b336b0416828022bfd5a2e3083e7f5ba54b96242159f83c7e3eebaec752f1716"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:c2faf60c583af0d135e853c86ac2735ce178f0e338a3c7f9ae8f622fd2eb788c"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:4bc6cb140a7a0ad1f7bc37e018d0ed690b7b6520ade518285dc3171f7a117905"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7ff762670cada8e05b32bf1e4dc50b140790909caa8303cfddc4d702b71ea184"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:57f0a0bbc9868e10ebe874e9f129d2917750adf008fe7b9c1598c0fbbfdde6a6"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:a6d2092797b388342c1bc932077ad232f914351932353e2e8706851c870bca1f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:60499fe961b21264e17a471ec296dcbf4365fbea611bf9e303ab69db7159ce61"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win32.whl", hash = "sha256:d9b342c76003c6b9336a80efcc766748a333573abf9350f4094ee46b006ec18f"},
|
||||
{file = "lxml-5.2.2-cp37-cp37m-win_amd64.whl", hash = "sha256:b16db2770517b8799c79aa80f4053cd6f8b716f21f8aca962725a9565ce3ee40"},
|
||||
@@ -1883,24 +1874,6 @@ files = [
|
||||
{file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mkl"
|
||||
version = "2021.4.0"
|
||||
description = "Intel® oneAPI Math Kernel Library"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
|
||||
{file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
|
||||
{file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
|
||||
{file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
|
||||
{file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
intel-openmp = "==2021.*"
|
||||
tbb = "==2021.*"
|
||||
|
||||
[[package]]
|
||||
name = "moviepy"
|
||||
version = "1.0.3"
|
||||
@@ -2292,12 +2265,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cudnn-cu12"
|
||||
version = "8.9.2.26"
|
||||
version = "9.1.0.70"
|
||||
description = "cuDNN runtime libraries"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
|
||||
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f"},
|
||||
{file = "nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2357,13 +2331,12 @@ nvidia-nvjitlink-cu12 = "*"
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-nccl-cu12"
|
||||
version = "2.20.5"
|
||||
version = "2.21.5"
|
||||
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01"},
|
||||
{file = "nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56"},
|
||||
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2373,6 +2346,7 @@ description = "Nvidia JIT LTO Library"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_aarch64.whl", hash = "sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.5.82-py3-none-win_amd64.whl", hash = "sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697"},
|
||||
]
|
||||
@@ -3240,6 +3214,33 @@ files = [
|
||||
[package.dependencies]
|
||||
six = ">=1.10.0"
|
||||
|
||||
[[package]]
|
||||
name = "pytorch-triton"
|
||||
version = "3.0.0+dedb7bdf33"
|
||||
description = "A language and compiler for custom Deep Learning operations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "pytorch_triton-3.0.0+dedb7bdf33-cp310-cp310-linux_x86_64.whl", hash = "sha256:55a025ff4fc61eff466e08eccd7c6bcebaae971f59094157f552074bbb618789"},
|
||||
{file = "pytorch_triton-3.0.0+dedb7bdf33-cp311-cp311-linux_x86_64.whl", hash = "sha256:46c9fd3d1c02a62201d98759ba4318ad89ca6c871a4c4a076f1ae667cc5a0eb7"},
|
||||
{file = "pytorch_triton-3.0.0+dedb7bdf33-cp312-cp312-linux_x86_64.whl", hash = "sha256:e7e3f94c08e389302944bc03cb6d4368560d5da4283799f2c8cd4f57727b1b54"},
|
||||
{file = "pytorch_triton-3.0.0+dedb7bdf33-cp38-cp38-linux_x86_64.whl", hash = "sha256:a095aabfbd38b3d596c859ac25e0352c869b8e659bc35093edf97794a59e6164"},
|
||||
{file = "pytorch_triton-3.0.0+dedb7bdf33-cp39-cp39-linux_x86_64.whl", hash = "sha256:5e65e0084a60b0001e5740c491225af81f312f81f6929f155225900f7ea9596b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
|
||||
[package.extras]
|
||||
build = ["cmake (>=3.20)", "lit"]
|
||||
tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"]
|
||||
tutorials = ["matplotlib", "pandas", "tabulate"]
|
||||
|
||||
[package.source]
|
||||
type = "legacy"
|
||||
url = "https://download.pytorch.org/whl/nightly/cu121"
|
||||
reference = "pytorch-nightly"
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2024.1"
|
||||
@@ -3276,6 +3277,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@@ -3925,19 +3927,6 @@ mpmath = ">=1.1.0,<1.4"
|
||||
[package.extras]
|
||||
dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "tbb"
|
||||
version = "2021.13.0"
|
||||
description = "Intel® oneAPI Threading Building Blocks (oneTBB)"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "tbb-2021.13.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:a2567725329639519d46d92a2634cf61e76601dac2f777a05686fea546c4fe4f"},
|
||||
{file = "tbb-2021.13.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aaf667e92849adb012b8874d6393282afc318aca4407fc62f912ee30a22da46a"},
|
||||
{file = "tbb-2021.13.0-py3-none-win32.whl", hash = "sha256:6669d26703e9943f6164c6407bd4a237a45007e79b8d3832fe6999576eaaa9ef"},
|
||||
{file = "tbb-2021.13.0-py3-none-win_amd64.whl", hash = "sha256:3528a53e4bbe64b07a6112b4c5a00ff3c61924ee46c9c68e004a1ac7ad1f09c3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.4.0"
|
||||
@@ -3982,95 +3971,86 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.3.1"
|
||||
version = "2.5.0.dev20240726+cu121"
|
||||
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "torch-2.3.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:605a25b23944be5ab7c3467e843580e1d888b8066e5aaf17ff7bf9cc30001cc3"},
|
||||
{file = "torch-2.3.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:f2357eb0965583a0954d6f9ad005bba0091f956aef879822274b1bcdb11bd308"},
|
||||
{file = "torch-2.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:32b05fe0d1ada7f69c9f86c14ff69b0ef1957a5a54199bacba63d22d8fab720b"},
|
||||
{file = "torch-2.3.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:7c09a94362778428484bcf995f6004b04952106aee0ef45ff0b4bab484f5498d"},
|
||||
{file = "torch-2.3.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:b2ec81b61bb094ea4a9dee1cd3f7b76a44555375719ad29f05c0ca8ef596ad39"},
|
||||
{file = "torch-2.3.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:490cc3d917d1fe0bd027057dfe9941dc1d6d8e3cae76140f5dd9a7e5bc7130ab"},
|
||||
{file = "torch-2.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:5802530783bd465fe66c2df99123c9a54be06da118fbd785a25ab0a88123758a"},
|
||||
{file = "torch-2.3.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:a7dd4ed388ad1f3d502bf09453d5fe596c7b121de7e0cfaca1e2017782e9bbac"},
|
||||
{file = "torch-2.3.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:a486c0b1976a118805fc7c9641d02df7afbb0c21e6b555d3bb985c9f9601b61a"},
|
||||
{file = "torch-2.3.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:224259821fe3e4c6f7edf1528e4fe4ac779c77addaa74215eb0b63a5c474d66c"},
|
||||
{file = "torch-2.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:e5fdccbf6f1334b2203a61a0e03821d5845f1421defe311dabeae2fc8fbeac2d"},
|
||||
{file = "torch-2.3.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:3c333dc2ebc189561514eda06e81df22bf8fb64e2384746b2cb9f04f96d1d4c8"},
|
||||
{file = "torch-2.3.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:07e9ba746832b8d069cacb45f312cadd8ad02b81ea527ec9766c0e7404bb3feb"},
|
||||
{file = "torch-2.3.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:462d1c07dbf6bb5d9d2f3316fee73a24f3d12cd8dacf681ad46ef6418f7f6626"},
|
||||
{file = "torch-2.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:ff60bf7ce3de1d43ad3f6969983f321a31f0a45df3690921720bcad6a8596cc4"},
|
||||
{file = "torch-2.3.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:bee0bd33dc58aa8fc8a7527876e9b9a0e812ad08122054a5bff2ce5abf005b10"},
|
||||
{file = "torch-2.3.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:aaa872abde9a3d4f91580f6396d54888620f4a0b92e3976a6034759df4b961ad"},
|
||||
{file = "torch-2.3.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3d7a7f7ef21a7520510553dc3938b0c57c116a7daee20736a9e25cbc0e832bdc"},
|
||||
{file = "torch-2.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:4777f6cefa0c2b5fa87223c213e7b6f417cf254a45e5829be4ccd1b2a4ee1011"},
|
||||
{file = "torch-2.3.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:2bb5af780c55be68fe100feb0528d2edebace1d55cb2e351de735809ba7391eb"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp310-cp310-linux_x86_64.whl", hash = "sha256:ecf50e5394403518c01c0906ef1e19bf7816a7a17ef733b401b47e389988546d"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp310-cp310-win_amd64.whl", hash = "sha256:6ea949aa1d2d186692f29e6f2bb03eb7bde5bd3d3232990fe4f66fb6fcbb5eef"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp311-cp311-linux_x86_64.whl", hash = "sha256:fc85bb92112b9a9691b23110ab40cb585826c1549a5d5e3ca9b58047162c036d"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp311-cp311-win_amd64.whl", hash = "sha256:7c4cb56b84bebb30ab9f20fd8be368409dd251d2e0d684705582ad80a5d725de"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp312-cp312-linux_x86_64.whl", hash = "sha256:d89f741ccd90073be452b293779e5ea58c10bd9dce2cc542355747d5a4c46a0e"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp312-cp312-win_amd64.whl", hash = "sha256:9b0a8770fddc28b42ade03e484bb7e42825b564fb9dbc34ddbc5e2bae5a04bab"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp313-cp313-linux_x86_64.whl", hash = "sha256:0e95dfd30616729f1ac5ceac1903fbdf336f905760949e6cfffbb1fb8de6f44c"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp38-cp38-linux_x86_64.whl", hash = "sha256:9daf32a6fe045d9d2f8af7fb2de07d7d37c79fb491ffc99a6470aa9296874051"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp38-cp38-win_amd64.whl", hash = "sha256:a5ed07b95ba63ecf6f952eef99cb0d71b07057034aab03f8cc904f0ddb4754b5"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp39-cp39-linux_x86_64.whl", hash = "sha256:7dc57e52f4dd965d2d963be80b0dac538983ae813db37adf2758115a05ad75bb"},
|
||||
{file = "torch-2.5.0.dev20240726+cu121-cp39-cp39-win_amd64.whl", hash = "sha256:53e4ea263449eb32cb130f2d33ea976a206705a1d695c06c4b2fc4e8f609ba0f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
fsspec = "*"
|
||||
jinja2 = "*"
|
||||
mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""}
|
||||
networkx = "*"
|
||||
nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cudnn-cu12 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-nccl-cu12 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-nccl-cu12 = {version = "2.21.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
sympy = "*"
|
||||
triton = {version = "2.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
|
||||
pytorch-triton = {version = "3.0.0+dedb7bdf33", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""}
|
||||
sympy = {version = ">=1.13.0", markers = "python_version >= \"3.9\""}
|
||||
typing-extensions = ">=4.8.0"
|
||||
|
||||
[package.extras]
|
||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||
optree = ["optree (>=0.9.1)"]
|
||||
optree = ["optree (>=0.12.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "legacy"
|
||||
url = "https://download.pytorch.org/whl/nightly/cu121"
|
||||
reference = "pytorch-nightly"
|
||||
|
||||
[[package]]
|
||||
name = "torchvision"
|
||||
version = "0.18.1"
|
||||
version = "0.20.0.dev20240726+cu121"
|
||||
description = "image and video datasets and models for torch deep learning"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "torchvision-0.18.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3e694e54b0548dad99c12af6bf0c8e4f3350137d391dcd19af22a1c5f89322b3"},
|
||||
{file = "torchvision-0.18.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:0b3bda0aa5b416eeb547143b8eeaf17720bdba9cf516dc991aacb81811aa96a5"},
|
||||
{file = "torchvision-0.18.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:573ff523c739405edb085f65cb592f482d28a30e29b0be4c4ba08040b3ae785f"},
|
||||
{file = "torchvision-0.18.1-cp310-cp310-win_amd64.whl", hash = "sha256:ef7bbbc60b38e831a75e547c66ca1784f2ac27100f9e4ddbe9614cef6cbcd942"},
|
||||
{file = "torchvision-0.18.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80b5d794dd0fdba787adc22f1a367a5ead452327686473cb260dd94364bc56a6"},
|
||||
{file = "torchvision-0.18.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:9077cf590cdb3a5e8fdf5cdb71797f8c67713f974cf0228ecb17fcd670ab42f9"},
|
||||
{file = "torchvision-0.18.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:ceb993a882f1ae7ae373ed39c28d7e3e802205b0e59a7ed84ef4028f0bba8d7f"},
|
||||
{file = "torchvision-0.18.1-cp311-cp311-win_amd64.whl", hash = "sha256:52f7436140045dc2239cdc502aa76b2bd8bd676d64244ff154d304aa69852046"},
|
||||
{file = "torchvision-0.18.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2be6f0bf7c455c89a51a1dbb6f668d36c6edc479f49ac912d745d10df5715657"},
|
||||
{file = "torchvision-0.18.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:f118d887bfde3a948a41d56587525401e5cac1b7db2eaca203324d6ed2b1caca"},
|
||||
{file = "torchvision-0.18.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:13d24d904f65e62d66a1e0c41faec630bc193867b8a4a01166769e8a8e8df8e9"},
|
||||
{file = "torchvision-0.18.1-cp312-cp312-win_amd64.whl", hash = "sha256:ed6340b69a63a625e512a66127210d412551d9c5f2ad2978130c6a45bf56cd4a"},
|
||||
{file = "torchvision-0.18.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b1c3864fa9378c88bce8ad0ef3599f4f25397897ce612e1c245c74b97092f35e"},
|
||||
{file = "torchvision-0.18.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:02085a2ffc7461f5c0edb07d6f3455ee1806561f37736b903da820067eea58c7"},
|
||||
{file = "torchvision-0.18.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:9726c316a2501df8503e5a5dc46a631afd4c515a958972e5b7f7b9c87d2125c0"},
|
||||
{file = "torchvision-0.18.1-cp38-cp38-win_amd64.whl", hash = "sha256:64a2662dbf30db9055d8b201d6e56f312a504e5ccd9d144c57c41622d3c524cb"},
|
||||
{file = "torchvision-0.18.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:975b8594c0f5288875408acbb74946eea786c5b008d129c0d045d0ead23742bc"},
|
||||
{file = "torchvision-0.18.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:da83c8bbd34d8bee48bfa1d1b40e0844bc3cba10ed825a5a8cbe3ce7b62264cd"},
|
||||
{file = "torchvision-0.18.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:54bfcd352abb396d5c9c237d200167c178bd136051b138e1e8ef46ce367c2773"},
|
||||
{file = "torchvision-0.18.1-cp39-cp39-win_amd64.whl", hash = "sha256:5c8366a1aeee49e9ea9e64b30d199debdf06b1bd7610a76165eb5d7869c3bde5"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp310-cp310-linux_x86_64.whl", hash = "sha256:d515b6989902d13303b0863081be9c338321b8786584d830f19d3935888fe37a"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp310-cp310-win_amd64.whl", hash = "sha256:bdeb4c2bf70ee8c0643fb1a9112b06214abf6065d92a72ea515230c0effa2f1f"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp311-cp311-linux_x86_64.whl", hash = "sha256:95b43e60834fce99c5a1c45e298d4a62feba2a3c3c62678978285b71f0e85d0c"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp311-cp311-win_amd64.whl", hash = "sha256:b0e2a0ee4e9dfd1842c43c60c54b671240785520f337bd14e15333653352dc5a"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp312-cp312-linux_x86_64.whl", hash = "sha256:642ba00eaf7b74599d20c367507bb0b70bd47f3b0d39ea543a3ccc0ffed82f4b"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp312-cp312-win_amd64.whl", hash = "sha256:e8ded5b6d1e1aae1fbaaba098fec610bdd1a0be32eddf914c180d663638eaf5e"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp38-cp38-linux_x86_64.whl", hash = "sha256:f8303f347c9c0d165a16fe3f0d6eefe9ad03bdbdb0bf4e6325c7b728cb4bf7a2"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp38-cp38-win_amd64.whl", hash = "sha256:e9e57eb6855cdfb3d597a05618214f1b41bdf81c8fd62c63c6f1caefbab6b9c5"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp39-cp39-linux_x86_64.whl", hash = "sha256:47fc41d8773a60a3a9fedcd5feacaaee874d5ce03cb5b92e4482f0a4c172c15f"},
|
||||
{file = "torchvision-0.20.0.dev20240726+cu121-cp39-cp39-win_amd64.whl", hash = "sha256:0995f390b8ba98da2624ff86542997fcb18a916dc59d63a54b5dcc5c7e607e11"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = "*"
|
||||
pillow = ">=5.3.0,<8.3.dev0 || >=8.4.dev0"
|
||||
torch = "2.3.1"
|
||||
torch = "2.5.0.dev20240726"
|
||||
|
||||
[package.extras]
|
||||
gdown = ["gdown (>=4.7.3)"]
|
||||
scipy = ["scipy"]
|
||||
|
||||
[package.source]
|
||||
type = "legacy"
|
||||
url = "https://download.pytorch.org/whl/nightly/cu121"
|
||||
reference = "pytorch-nightly"
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.4"
|
||||
@@ -4091,29 +4071,6 @@ notebook = ["ipywidgets (>=6)"]
|
||||
slack = ["slack-sdk"]
|
||||
telegram = ["requests"]
|
||||
|
||||
[[package]]
|
||||
name = "triton"
|
||||
version = "2.3.1"
|
||||
description = "A language and compiler for custom Deep Learning operations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "triton-2.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c84595cbe5e546b1b290d2a58b1494df5a2ef066dd890655e5b8a8a92205c33"},
|
||||
{file = "triton-2.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9d64ae33bcb3a7a18081e3a746e8cf87ca8623ca13d2c362413ce7a486f893e"},
|
||||
{file = "triton-2.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaf80e8761a9e3498aa92e7bf83a085b31959c61f5e8ac14eedd018df6fccd10"},
|
||||
{file = "triton-2.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b13bf35a2b659af7159bf78e92798dc62d877aa991de723937329e2d382f1991"},
|
||||
{file = "triton-2.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63381e35ded3304704ea867ffde3b7cfc42c16a55b3062d41e017ef510433d66"},
|
||||
{file = "triton-2.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d968264523c7a07911c8fb51b4e0d1b920204dae71491b1fe7b01b62a31e124"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
|
||||
[package.extras]
|
||||
build = ["cmake (>=3.20)", "lit"]
|
||||
tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"]
|
||||
tutorials = ["matplotlib", "pandas", "tabulate", "torch"]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.12.2"
|
||||
@@ -4485,4 +4442,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "dfe9c6a54e0382156e62e7bd2c7aab1be6372da76d30c61b06d27232276638cb"
|
||||
content-hash = "d205dcc3a9df41c8ced0c51341f415c408bcb7b09c8d74724ce4b2c3a1708fad"
|
||||
|
||||
@@ -25,6 +25,10 @@ classifiers=[
|
||||
]
|
||||
packages = [{include = "lerobot"}]
|
||||
|
||||
[[tool.poetry.source]]
|
||||
name = "pytorch-nightly"
|
||||
url = "https://download.pytorch.org/whl/nightly/cu121"
|
||||
secondary = true
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.13"
|
||||
@@ -38,10 +42,10 @@ einops = ">=0.8.0"
|
||||
pymunk = ">=6.6.0"
|
||||
zarr = ">=2.17.0"
|
||||
numba = ">=0.59.0"
|
||||
torch = ">=2.2.1"
|
||||
torch = { version = "*", source = "pytorch-nightly" }
|
||||
torchvision = { version = "*", source = "pytorch-nightly" }
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer"], version = ">=0.23.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
# limitations under the License.
|
||||
import pytest
|
||||
|
||||
from .utils import DEVICE
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
from .utils import DEVICE, KOCH_ROBOT_CONFIG_PATH
|
||||
|
||||
|
||||
def pytest_collection_finish():
|
||||
@@ -27,11 +29,12 @@ def is_koch_available():
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
|
||||
robot = make_robot("koch")
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
|
||||
robot = make_robot(robot_cfg)
|
||||
robot.connect()
|
||||
del robot
|
||||
return True
|
||||
except Exception as e:
|
||||
print("An alexander koch robot is not available.")
|
||||
print("A koch robot is not available.")
|
||||
print(e)
|
||||
return False
|
||||
|
||||
@@ -3,13 +3,19 @@ from pathlib import Path
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.scripts.control_robot import record_dataset, replay_episode, run_policy, teleoperate
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_koch
|
||||
from lerobot.scripts.control_robot import record, replay, teleoperate
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, KOCH_ROBOT_CONFIG_PATH, require_koch
|
||||
|
||||
|
||||
def make_robot_():
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
|
||||
robot = make_robot(robot_cfg)
|
||||
return robot
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_teleoperate(request):
|
||||
robot = make_robot("koch")
|
||||
robot = make_robot_()
|
||||
teleoperate(robot, teleop_time_s=1)
|
||||
teleoperate(robot, fps=30, teleop_time_s=1)
|
||||
teleoperate(robot, fps=60, teleop_time_s=1)
|
||||
@@ -17,20 +23,19 @@ def test_teleoperate(request):
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
|
||||
robot_name = "koch"
|
||||
def test_record_and_replay_and_policy(tmpdir, request):
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
root = Path(tmpdir)
|
||||
repo_id = "lerobot/debug"
|
||||
|
||||
robot = make_robot(robot_name)
|
||||
dataset = record_dataset(
|
||||
robot = make_robot_()
|
||||
dataset = record(
|
||||
robot, fps=30, root=root, repo_id=repo_id, warmup_time_s=1, episode_time_s=1, num_episodes=2
|
||||
)
|
||||
|
||||
replay_episode(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
replay(robot, episode=0, fps=30, root=root, repo_id=repo_id)
|
||||
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
@@ -43,6 +48,6 @@ def test_record_dataset_and_replay_episode_and_run_policy(tmpdir, request):
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
|
||||
run_policy(robot, policy, cfg, run_time_s=1)
|
||||
record(robot, policy, cfg, run_time_s=1)
|
||||
|
||||
del robot
|
||||
|
||||
@@ -1,33 +1,54 @@
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): test calibration
|
||||
# TODO(rcadene): add compatibility with other motors bus
|
||||
|
||||
import time
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||
from tests.utils import require_koch
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.utils import KOCH_ROBOT_CONFIG_PATH, require_koch
|
||||
|
||||
|
||||
def make_motors_bus():
|
||||
robot_cfg = init_hydra_config(KOCH_ROBOT_CONFIG_PATH)
|
||||
# Instantiating a common motors structure.
|
||||
# Here the one from Alexander Koch follower arm.
|
||||
motors_bus = hydra.utils.instantiate(robot_cfg.leader_arms.main)
|
||||
return motors_bus
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_find_port(request):
|
||||
from lerobot.common.robot_devices.motors.dynamixel import find_port
|
||||
|
||||
find_port()
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_configure_motors_all_ids_1(request):
|
||||
# This test expect the configuration was already correct.
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
motors_bus.write("Baud_Rate", [0] * len(motors_bus.motors))
|
||||
motors_bus.set_bus_baudrate(9_600)
|
||||
motors_bus.write("ID", [1] * len(motors_bus.motors))
|
||||
del motors_bus
|
||||
|
||||
# Test configure
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
assert motors_bus.are_motors_configured()
|
||||
del motors_bus
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_motors_bus(request):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): test calibration
|
||||
# TODO(rcadene): add compatibility with other motors bus
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
|
||||
# Test instantiating a common motors structure.
|
||||
# Here the one from Alexander Koch follower arm.
|
||||
port = "/dev/tty.usbmodem575E0032081"
|
||||
motors = {
|
||||
# name: (index, model)
|
||||
"shoulder_pan": (1, "xl430-w250"),
|
||||
"shoulder_lift": (2, "xl430-w250"),
|
||||
"elbow_flex": (3, "xl330-m288"),
|
||||
"wrist_flex": (4, "xl330-m288"),
|
||||
"wrist_roll": (5, "xl330-m288"),
|
||||
"gripper": (6, "xl330-m288"),
|
||||
}
|
||||
motors_bus = DynamixelMotorsBus(port, motors)
|
||||
motors_bus = make_motors_bus()
|
||||
|
||||
# Test reading and writting before connecting raises an error
|
||||
with pytest.raises(RobotDeviceNotConnectedError):
|
||||
@@ -41,7 +62,7 @@ def test_motors_bus(request):
|
||||
del motors_bus
|
||||
|
||||
# Test connecting
|
||||
motors_bus = DynamixelMotorsBus(port, motors)
|
||||
motors_bus = make_motors_bus()
|
||||
motors_bus.connect()
|
||||
|
||||
# Test connecting twice raises an error
|
||||
@@ -52,7 +73,7 @@ def test_motors_bus(request):
|
||||
motors_bus.write("Torque_Enable", 0)
|
||||
values = motors_bus.read("Torque_Enable")
|
||||
assert isinstance(values, np.ndarray)
|
||||
assert len(values) == len(motors)
|
||||
assert len(values) == len(motors_bus.motors)
|
||||
assert (values == 0).all()
|
||||
|
||||
# Test writing torque on a specific motor
|
||||
@@ -83,10 +104,3 @@ def test_motors_bus(request):
|
||||
time.sleep(1)
|
||||
new_values = motors_bus.read("Present_Position")
|
||||
assert (new_values == values).all()
|
||||
|
||||
|
||||
@require_koch
|
||||
def test_find_port(request):
|
||||
from lerobot.common.robot_devices.motors.dynamixel import find_port
|
||||
|
||||
find_port()
|
||||
|
||||
@@ -23,6 +23,7 @@ from lerobot.common.utils.import_utils import is_package_available
|
||||
|
||||
# Pass this as the first argument to init_hydra_config.
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
KOCH_ROBOT_CONFIG_PATH = "lerobot/configs/robot/koch.yaml"
|
||||
|
||||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@@ -161,6 +162,7 @@ def require_koch(func):
|
||||
if request is None:
|
||||
raise ValueError("The 'request' fixture must be passed to the test function as a parameter.")
|
||||
|
||||
# The function `is_koch_available` is defined in `tests/conftest.py`
|
||||
if not request.getfixturevalue("is_koch_available"):
|
||||
pytest.skip("An alexander koch robot is not available.")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user