import logging
import time
from functools import cached_property
from typing import Any

from lerobot.robots import Robot, make_robot_from_config
from lerobot.robots.robot import Robot
from lerobot.utils.errors import DeviceNotConnectedError

from .config import MixRobotConfig

logger = logging.getLogger(__name__)


class MixRobot(Robot):
    """Composite robot that exposes several named sub-robots as a single device.

    Sub-robots are built from ``config.robotList`` (a mapping of name to
    sub-robot config) with ``make_robot_from_config``. Every feature,
    observation and action key of a sub-robot is exposed under the namespaced
    key ``"<robot name>.<original key>"``; ``send_action`` uses the text before
    the first ``"."`` of each action key to route it back to its sub-robot.
    """

    config_class = MixRobotConfig
    name = "mix"

    def __init__(self, config: MixRobotConfig):
        """Build one sub-robot per entry of ``config.robotList``.

        Args:
            config: Configuration whose ``robotList`` maps a sub-robot name to
                the config object used to instantiate it.
        """
        super().__init__(config)
        self.config = config
        # Distinct loop names: reusing `config` here would shadow the parameter.
        self.robotList: dict[str, Robot] = {
            robot_name: make_robot_from_config(robot_config)
            for robot_name, robot_config in config.robotList.items()
        }

    @staticmethod
    def _prefix_keys(robot_name: str, mapping: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of *mapping* with every key prefixed by ``"<robot_name>."``."""
        return {f"{robot_name}.{key}": value for key, value in mapping.items()}

    @property
    def _motors_ft(self) -> dict[str, type]:
        """Motor features of all sub-robots, keyed ``"<robot name>.<motor key>"``.

        Reads the ``_motors_ft`` mapping of each sub-robot.
        """
        features: dict[str, type] = {}
        for robot_name, robot in self.robotList.items():
            # The keys must be the ones get_observation() produces, so the same
            # prefixing helper is used in both places.
            features.update(self._prefix_keys(robot_name, robot._motors_ft))
        return features

    @property
    def _cameras_ft(self) -> dict[str, tuple]:
        """Camera features of all sub-robots, keyed ``"<robot name>.<camera key>"``.

        Reads the ``_cameras_ft`` mapping of each sub-robot.
        """
        features: dict[str, tuple] = {}
        for robot_name, robot in self.robotList.items():
            features.update(self._prefix_keys(robot_name, robot._cameras_ft))
        return features

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        """Merged motor and camera features; computed once and cached."""
        return {**self._motors_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:
        """Action features of all sub-robots, keyed ``"<robot name>.<action key>"``.

        Computed once and cached.
        """
        features: dict[str, type] = {}
        for robot_name, robot in self.robotList.items():
            features.update(self._prefix_keys(robot_name, robot.action_features))
        return features

    @property
    def is_connected(self) -> bool:
        """``True`` only when every sub-robot reports being connected."""
        return all(robot.is_connected for robot in self.robotList.values())

    def connect(self, calibrate: bool = True) -> None:
        """Connect every sub-robot, forwarding *calibrate* to each of them.

        If one sub-robot fails to connect, the sub-robots connected earlier in
        this call are disconnected again before the error is re-raised.
        Otherwise they would stay connected while ``is_connected`` is ``False``,
        and ``disconnect()`` would refuse to release them.

        Args:
            calibrate: Passed unchanged to each sub-robot's ``connect``.
        """
        connected: list[Robot] = []
        try:
            for robot in self.robotList.values():
                robot.connect(calibrate)
                connected.append(robot)
        except Exception:
            for robot in reversed(connected):
                try:
                    robot.disconnect()
                except Exception:
                    # Keep rolling back the others; the original error is
                    # the one re-raised below.
                    logger.exception("Failed to roll back connection of %s", robot)
            raise

    @property
    def is_calibrated(self) -> bool:
        """``True`` only when every sub-robot reports being calibrated."""
        return all(robot.is_calibrated for robot in self.robotList.values())

    def calibrate(self) -> None:
        """No-op at the composite level."""
        pass

    def configure(self) -> None:
        """No-op at the composite level."""
        pass

    def setup_motors(self) -> None:
        """No-op at the composite level."""
        pass

    def get_observation(self) -> dict[str, Any]:
        """Collect the observations of all sub-robots into one dictionary.

        Returns:
            The union of every sub-robot's observation, each key prefixed with
            ``"<robot name>."``.

        Raises:
            DeviceNotConnectedError: If any sub-robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        observation: dict[str, Any] = {}
        for robot_name, robot in self.robotList.items():
            observation.update(self._prefix_keys(robot_name, robot.get_observation()))
        return observation

    def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
        """Split *action* by robot-name prefix and forward each part.

        The text before the first ``"."`` of a key selects the sub-robot; the
        remainder of the key is what that sub-robot receives. Keys whose prefix
        does not name a sub-robot are ignored.

        Args:
            action: Mapping of ``"<robot name>.<action key>"`` to value.

        Returns:
            The actions reported as sent by the sub-robots, with keys prefixed
            by ``"<robot name>."`` again. When a sub-robot returns ``None``,
            the sub-action that was handed to it is reported instead.

        Raises:
            DeviceNotConnectedError: If any sub-robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        per_robot: dict[str, dict[str, Any]] = {}
        for key, value in action.items():
            robot_name = key.split(".")[0]
            sub_key = key.removeprefix(robot_name + ".")
            per_robot.setdefault(robot_name, {})[sub_key] = value

        sent: dict[str, Any] = {}
        for robot_name, sub_action in per_robot.items():
            robot = self.robotList.get(robot_name)
            if robot is None:
                # Not addressed to any sub-robot of this composite: skip it.
                continue
            result = robot.send_action(sub_action)
            if result is None:
                result = sub_action
            sent.update(self._prefix_keys(robot_name, result))
        return sent

    def disconnect(self):
        """Disconnect every sub-robot.

        Raises:
            DeviceNotConnectedError: If any sub-robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        for robot in self.robotList.values():
            robot.disconnect()