add train and evals
This commit is contained in:
@@ -11,13 +11,13 @@ from .robot import Robot
|
||||
FPS = 30
|
||||
|
||||
CAMERAS_SHAPES = {
|
||||
"observation.images.high": (480, 640, 3),
|
||||
"observation.images.low": (480, 640, 3),
|
||||
"images.high": (480, 640, 3),
|
||||
"images.low": (480, 640, 3),
|
||||
}
|
||||
|
||||
CAMERAS_PORTS = {
|
||||
"observation.images.high": "/dev/video6",
|
||||
"observation.images.low": "/dev/video0",
|
||||
"images.high": "/dev/video6",
|
||||
"images.low": "/dev/video0",
|
||||
}
|
||||
|
||||
LEADER_PORT = "/dev/ttyACM1"
|
||||
@@ -52,6 +52,8 @@ class RealEnv(gym.Env):
|
||||
leader_port: str = LEADER_PORT,
|
||||
warmup_steps: int = 100,
|
||||
trigger_torque=70,
|
||||
fps: int = FPS,
|
||||
fps_tolerance: float = 0.1,
|
||||
):
|
||||
self.num_joints = num_joints
|
||||
self.cameras_shapes = cameras_shapes
|
||||
@@ -62,6 +64,8 @@ class RealEnv(gym.Env):
|
||||
self.follower_port = follower_port
|
||||
self.leader_port = leader_port
|
||||
self.record = record
|
||||
self.fps = fps
|
||||
self.fps_tolerance = fps_tolerance
|
||||
|
||||
# Initialize the robot
|
||||
self.follower = Robot(device_name=self.follower_port)
|
||||
@@ -72,10 +76,13 @@ class RealEnv(gym.Env):
|
||||
# Initialize the cameras - sorted by camera names
|
||||
self.cameras = {}
|
||||
for cn, p in sorted(self.cameras_ports.items()):
|
||||
assert cn.startswith("observation.images."), "Camera names must start with 'observation.images.'."
|
||||
self.cameras[cn] = cv2.VideoCapture(p)
|
||||
if not all(c.isOpened() for c in self.cameras.values()):
|
||||
raise OSError("Cannot open all camera ports.")
|
||||
if not self.cameras[cn].isOpened():
|
||||
raise OSError(
|
||||
f"Cannot open camera port {p} for {cn}."
|
||||
f" Make sure the camera is connected and the port is correct."
|
||||
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
|
||||
)
|
||||
|
||||
# Specify gym action and observation spaces
|
||||
observation_space = {}
|
||||
@@ -98,7 +105,7 @@ class RealEnv(gym.Env):
|
||||
if self.cameras_shapes:
|
||||
for cn, hwc_shape in self.cameras_shapes.items():
|
||||
# Assumes images are unsigned int8 in [0,255]
|
||||
observation_space[f"images.{cn}"] = spaces.Box(
|
||||
observation_space[cn] = spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
# height x width x channels (e.g. 480 x 640 x 3)
|
||||
@@ -111,22 +118,20 @@ class RealEnv(gym.Env):
|
||||
|
||||
self._observation = {}
|
||||
self._terminated = False
|
||||
self._action_time = time.time()
|
||||
self.starting_time = time.time()
|
||||
self.timestamps = []
|
||||
|
||||
def _get_obs(self):
|
||||
qpos = self.follower.read_position()
|
||||
self._observation["agent_pos"] = pwm2pos(qpos)
|
||||
for cn, c in self.cameras.items():
|
||||
self._observation[f"images.{cn}"] = capture_image(
|
||||
c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0]
|
||||
)
|
||||
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
|
||||
|
||||
if self.record:
|
||||
leader_pos = self.leader.read_position()
|
||||
self._observation["leader_pos"] = pwm2pos(leader_pos)
|
||||
action = self.leader.read_position()
|
||||
self._observation["leader_pos"] = pwm2pos(action)
|
||||
|
||||
def reset(self, seed: int | None = None):
|
||||
del seed
|
||||
# Reset the robot and sync the leader and follower if we are recording
|
||||
for _ in range(self.warmup_steps):
|
||||
self._get_obs()
|
||||
@@ -134,10 +139,22 @@ class RealEnv(gym.Env):
|
||||
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
|
||||
self._terminated = False
|
||||
info = {}
|
||||
self.timestamps = []
|
||||
return self._observation, info
|
||||
|
||||
def step(self, action: np.ndarray = None):
|
||||
# Reset the observation
|
||||
if self.timestamps:
|
||||
# wait the right amount of time to stay at the desired fps
|
||||
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
|
||||
recording_time = time.time() - self.starting_time
|
||||
else:
|
||||
# it's the first step so we start the timer
|
||||
self.starting_time = time.time()
|
||||
recording_time = 0
|
||||
|
||||
self.timestamps.append(recording_time)
|
||||
|
||||
# Get the observation
|
||||
self._get_obs()
|
||||
if self.record:
|
||||
# Teleoperate the leader
|
||||
@@ -145,9 +162,20 @@ class RealEnv(gym.Env):
|
||||
else:
|
||||
# Apply the action to the follower
|
||||
self.follower.set_goal_pos(pos2pwm(action))
|
||||
|
||||
reward = 0
|
||||
terminated = truncated = self._terminated
|
||||
info = {}
|
||||
info = {"timestamp": recording_time, "fps_error": False}
|
||||
|
||||
# Check if we are able to keep up with the desired fps
|
||||
if recording_time - self.timestamps[-1] > 1 / (self.fps - self.fps_tolerance):
|
||||
print(
|
||||
f"Error: recording time interval {recording_time - self.timestamps[-1]:.2f} is greater"
|
||||
f"than expected {1 / (self.fps - self.fps_tolerance):.2f}"
|
||||
f" at frame {len(self.timestamps)}"
|
||||
)
|
||||
info["fps_error"] = True
|
||||
|
||||
return self._observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self): ...
|
||||
|
||||
Reference in New Issue
Block a user