fix unit test
This commit is contained in:
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@@ -452,6 +453,7 @@ class KochRobot:
|
||||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self):
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
@@ -476,13 +478,18 @@ class KochRobot:
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = torch.from_numpy(state)
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
|
||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||
img = torch.from_numpy(images[name])
|
||||
img = img.type(torch.float32) / 255
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
obs_dict[f"observation.images.{name}"] = img
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor):
|
||||
"""The provided action is expected to be a vector."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
|
||||
from_idx = 0
|
||||
to_idx = 0
|
||||
follower_goal_pos = {}
|
||||
|
||||
Reference in New Issue
Block a user