forked from tangger/lerobot
updating with adding masking in ACT - start adding some tests
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
@@ -23,6 +24,14 @@ CAMERAS_PORTS = {
|
||||
LEADER_PORT = "/dev/ttyACM1"
|
||||
FOLLOWER_PORT = "/dev/ttyACM0"
|
||||
|
||||
MockRobot = MagicMock()
|
||||
MockRobot.read_position = MagicMock()
|
||||
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
|
||||
|
||||
MockCamera = MagicMock()
|
||||
MockCamera.isOpened = MagicMock(return_value=True)
|
||||
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
|
||||
|
||||
|
||||
def capture_image(cam, cam_width, cam_height):
|
||||
# Capture a single frame
|
||||
@@ -54,6 +63,7 @@ class RealEnv(gym.Env):
|
||||
trigger_torque=70,
|
||||
fps: int = FPS,
|
||||
fps_tolerance: float = 0.1,
|
||||
mock: bool = False,
|
||||
):
|
||||
self.num_joints = num_joints
|
||||
self.cameras_shapes = cameras_shapes
|
||||
@@ -68,15 +78,15 @@ class RealEnv(gym.Env):
|
||||
self.fps_tolerance = fps_tolerance
|
||||
|
||||
# Initialize the robot
|
||||
self.follower = Robot(device_name=self.follower_port)
|
||||
self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
|
||||
if self.record:
|
||||
self.leader = Robot(device_name=self.leader_port)
|
||||
self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
|
||||
self.leader.set_trigger_torque(trigger_torque)
|
||||
|
||||
# Initialize the cameras - sorted by camera names
|
||||
self.cameras = {}
|
||||
for cn, p in sorted(self.cameras_ports.items()):
|
||||
self.cameras[cn] = cv2.VideoCapture(p)
|
||||
self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
|
||||
if not self.cameras[cn].isOpened():
|
||||
raise OSError(
|
||||
f"Cannot open camera port {p} for {cn}."
|
||||
@@ -118,7 +128,6 @@ class RealEnv(gym.Env):
|
||||
|
||||
self._observation = {}
|
||||
self._terminated = False
|
||||
self.starting_time = time.time()
|
||||
self.timestamps = []
|
||||
|
||||
def _get_obs(self):
|
||||
@@ -146,13 +155,8 @@ class RealEnv(gym.Env):
|
||||
if self.timestamps:
|
||||
# wait the right amount of time to stay at the desired fps
|
||||
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
|
||||
recording_time = time.time() - self.starting_time
|
||||
else:
|
||||
# it's the first step so we start the timer
|
||||
self.starting_time = time.time()
|
||||
recording_time = 0
|
||||
|
||||
self.timestamps.append(recording_time)
|
||||
self.timestamps.append(time.time())
|
||||
|
||||
# Get the observation
|
||||
self._get_obs()
|
||||
@@ -165,13 +169,15 @@ class RealEnv(gym.Env):
|
||||
|
||||
reward = 0
|
||||
terminated = truncated = self._terminated
|
||||
info = {"timestamp": recording_time, "fps_error": False}
|
||||
info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
|
||||
|
||||
# Check if we are able to keep up with the desired fps
|
||||
if recording_time - self.timestamps[-1] > 1 / (self.fps - self.fps_tolerance):
|
||||
if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
|
||||
self.fps - self.fps_tolerance
|
||||
):
|
||||
print(
|
||||
f"Error: recording time interval {recording_time - self.timestamps[-1]:.2f} is greater"
|
||||
f"than expected {1 / (self.fps - self.fps_tolerance):.2f}"
|
||||
f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
|
||||
f" than min admited fps {(self.fps - self.fps_tolerance):.5f}"
|
||||
f" at frame {len(self.timestamps)}"
|
||||
)
|
||||
info["fps_error"] = True
|
||||
|
||||
Reference in New Issue
Block a user