@@ -226,6 +226,42 @@ class ManipulatorRobot:
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
action_names = self.get_motor_names(self.leader_arms)
|
||||
state_names = self.get_motor_names(self.leader_arms)
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action_names),),
|
||||
"names": action_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state_names),),
|
||||
"names": state_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
|
||||
@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
|
||||
class Robot(Protocol):
|
||||
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
|
||||
robot_type: str
|
||||
features: dict
|
||||
|
||||
def connect(self): ...
|
||||
def run_calibration(self): ...
|
||||
|
||||
Reference in New Issue
Block a user