feat: Add fixes and refactor lekiwi example (#1396)
* feat: Add fixes and refactor lekiwi example * fix: replace repo_id with placeholders * feat: use record_loop for lekiwi, use same control structure as record.py * feat: make rerun log more general for lekiwi * fix: add comments record_loop and fix params evaluate.py * fix: add events in evaluate.py * fix: add events 2 * change record to display data * Integrate feedback steven * Add docs merging * fix: add lekiwi name check * fix: integrate feedback steven * fix: list for type * fix: check type list * remove second robot connect * fix: added file when merging * fix(record): account for edge cases when teleop is a list --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -40,9 +40,7 @@ import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
from typing import List
|
||||
|
||||
from lerobot.cameras import ( # noqa: F401
|
||||
CameraConfig, # noqa: F401
|
||||
@@ -72,6 +70,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so100_leader,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
|
||||
from lerobot.utils.control_utils import (
|
||||
init_keyboard_listener,
|
||||
is_headless,
|
||||
@@ -85,7 +84,7 @@ from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
)
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -165,7 +164,7 @@ def record_loop(
|
||||
events: dict,
|
||||
fps: int,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | None = None,
|
||||
teleop: Teleoperator | List[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
@@ -174,6 +173,23 @@ def record_loop(
|
||||
if dataset is not None and dataset.fps != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
|
||||
|
||||
teleop_arm = teleop_keyboard = None
|
||||
if isinstance(teleop, list):
|
||||
teleop_keyboard = next((t for t in teleop if isinstance(t, KeyboardTeleop)), None)
|
||||
teleop_arm = next(
|
||||
(
|
||||
t
|
||||
for t in teleop
|
||||
if isinstance(t, (so100_leader.SO100Leader, so101_leader.SO101Leader, koch_leader.KochLeader))
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not (teleop_arm and teleop_keyboard and len(teleop) == 2 and robot.name == "lekiwi_client"):
|
||||
raise ValueError(
|
||||
"For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot."
|
||||
)
|
||||
|
||||
# if policy is given it needs cleaning up
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
@@ -202,8 +218,17 @@ def record_loop(
|
||||
robot_type=robot.robot_type,
|
||||
)
|
||||
action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
|
||||
elif policy is None and teleop is not None:
|
||||
elif policy is None and isinstance(teleop, Teleoperator):
|
||||
action = teleop.get_action()
|
||||
elif policy is None and isinstance(teleop, list):
|
||||
# TODO(pepijn, steven): clean the record loop for use of multiple robots (possibly with pipeline)
|
||||
arm_action = teleop_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
|
||||
keyboard_action = teleop_keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_action)
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
@@ -222,14 +247,7 @@ def record_loop(
|
||||
dataset.add_frame(frame, task=single_task)
|
||||
|
||||
if display_data:
|
||||
for obs, val in observation.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"observation.{obs}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
rr.log(f"observation.{obs}", rr.Image(val), static=True)
|
||||
for act, val in action.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"action.{act}", rr.Scalar(val))
|
||||
log_rerun_data(observation, action)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from lerobot.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
@@ -195,26 +194,23 @@ class LeKiwiClient(Robot):
|
||||
self, observation: Dict[str, Any]
|
||||
) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
|
||||
"""Extracts frames, and state from the parsed observation."""
|
||||
flat_state = {key: value for key, value in observation.items() if key in self._state_ft}
|
||||
|
||||
state_vec = np.array(
|
||||
[flat_state.get(k, 0.0) for k in self._state_order],
|
||||
dtype=np.float32,
|
||||
)
|
||||
flat_state = {key: observation.get(key, 0.0) for key in self._state_order}
|
||||
|
||||
state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32)
|
||||
|
||||
obs_dict: Dict[str, Any] = {**flat_state, "observation.state": state_vec}
|
||||
|
||||
# Decode images
|
||||
image_observation = {
|
||||
f"observation.images.{key}": value
|
||||
for key, value in observation.items()
|
||||
if key in self._cameras_ft
|
||||
}
|
||||
current_frames: Dict[str, np.ndarray] = {}
|
||||
for cam_name, image_b64 in image_observation.items():
|
||||
for cam_name, image_b64 in observation.items():
|
||||
if cam_name not in self._cameras_ft:
|
||||
continue
|
||||
frame = self._decode_image_from_b64(image_b64)
|
||||
if frame is not None:
|
||||
current_frames[cam_name] = frame
|
||||
|
||||
return current_frames, {"observation.state": state_vec}
|
||||
return current_frames, obs_dict
|
||||
|
||||
def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
@@ -267,7 +263,7 @@ class LeKiwiClient(Robot):
|
||||
if frame is None:
|
||||
logging.warning("Frame is None")
|
||||
frame = np.zeros((640, 480, 3), dtype=np.uint8)
|
||||
obs_dict[cam_name] = torch.from_numpy(frame)
|
||||
obs_dict[cam_name] = frame
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -327,7 +323,10 @@ class LeKiwiClient(Robot):
|
||||
|
||||
# TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value
|
||||
actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32)
|
||||
return {"action": actions}
|
||||
|
||||
action_sent = {key: actions[i] for i, key in enumerate(self._state_order)}
|
||||
action_sent["action"] = actions
|
||||
return action_sent
|
||||
|
||||
def disconnect(self):
|
||||
"""Cleans ZMQ comms"""
|
||||
|
||||
@@ -36,7 +36,6 @@ from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
import draccus
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
@@ -60,11 +59,12 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
)
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import _init_rerun
|
||||
from lerobot.utils.visualization_utils import _init_rerun, log_rerun_data
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeleoperateConfig:
|
||||
# TODO: pepijn, steven: if more robots require multiple teleoperators (like lekiwi), it's good to make this possible in teleop.py and record.py with List[Teleoperator]
|
||||
teleop: TeleoperatorConfig
|
||||
robot: RobotConfig
|
||||
# Limit the maximum frames per second.
|
||||
@@ -84,14 +84,7 @@ def teleop_loop(
|
||||
action = teleop.get_action()
|
||||
if display_data:
|
||||
observation = robot.get_observation()
|
||||
for obs, val in observation.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"observation_{obs}", rr.Scalar(val))
|
||||
elif isinstance(val, np.ndarray):
|
||||
rr.log(f"observation_{obs}", rr.Image(val), static=True)
|
||||
for act, val in action.items():
|
||||
if isinstance(val, float):
|
||||
rr.log(f"action_{act}", rr.Scalar(val))
|
||||
log_rerun_data(observation, action)
|
||||
|
||||
robot.send_action(action)
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
|
||||
@@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
|
||||
@@ -24,3 +26,21 @@ def _init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
rr.init(session_name)
|
||||
memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%")
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def _is_rerun_scalar(val: Any) -> bool:
    """Return True when *val* is a single real number that can be logged as a scalar."""
    # bool is a subclass of int; boolean flags were never logged, so keep skipping them.
    if isinstance(val, (bool, np.bool_)):
        return False
    # NumPy scalars such as np.float32 / np.int64 are not instances of the
    # builtin float, so they have to be listed explicitly.
    if isinstance(val, (int, float, np.integer, np.floating)):
        return True
    # A 0-d numeric array holds exactly one value; it cannot be iterated.
    return isinstance(val, np.ndarray) and val.ndim == 0 and np.issubdtype(val.dtype, np.number)


def log_rerun_data(observation: dict[str, Any], action: dict[str, Any]) -> None:
    """Log one step of observation and action values through ``rr.log``.

    Observation entries are logged under ``observation.<key>``:
      * real-number scalars (Python or NumPy) as ``rr.Scalar``;
      * 1-D arrays element-wise as ``observation.<key>_<i>`` scalars;
      * any other array as ``rr.Image`` with ``static=True``.

    Action entries are logged under ``action.<key>``:
      * real-number scalars (Python or NumPy) as ``rr.Scalar``;
      * arrays element-wise, in flattened order, as ``action.<key>_<i>`` scalars.

    Values of any other type (strings, booleans, ``None``, ...) are skipped.

    Args:
        observation: Mapping from observation name to its value.
        action: Mapping from action name to its value.
    """
    for obs, val in observation.items():
        if _is_rerun_scalar(val):
            rr.log(f"observation.{obs}", rr.Scalar(float(val)))
        elif isinstance(val, np.ndarray):
            if val.ndim == 1:
                for i, v in enumerate(val):
                    rr.log(f"observation.{obs}_{i}", rr.Scalar(float(v)))
            else:
                rr.log(f"observation.{obs}", rr.Image(val), static=True)

    for act, val in action.items():
        if _is_rerun_scalar(val):
            rr.log(f"action.{act}", rr.Scalar(float(val)))
        elif isinstance(val, np.ndarray):
            # Flatten so that multi-dimensional arrays are logged element by
            # element instead of failing on float(row).
            for i, v in enumerate(val.ravel()):
                rr.log(f"action.{act}_{i}", rr.Scalar(float(v)))
|
||||
|
||||
Reference in New Issue
Block a user