fix unit test

This commit is contained in:
Remi Cadene
2024-07-10 14:05:58 +02:00
parent 52e760a88e
commit 68a561570c
2 changed files with 16 additions and 2 deletions

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
from pathlib import Path
import time
import einops
import numpy as np
import torch
@@ -452,6 +453,7 @@ class KochRobot:
return obs_dict, action_dict
def capture_observation(self):
"""The returned observations do not have a batch dimension."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
@@ -476,13 +478,18 @@ class KochRobot:
obs_dict = {}
obs_dict["observation.state"] = torch.from_numpy(state)
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = torch.from_numpy(images[name])
# Convert to pytorch format: channel first and float32 in [0,1]
img = torch.from_numpy(images[name])
img = img.type(torch.float32) / 255
img = img.permute(2, 0, 1).contiguous()
obs_dict[f"observation.images.{name}"] = img
return obs_dict
def send_action(self, action: torch.Tensor):
"""The provided action is expected to be a vector."""
if not self.is_connected:
raise RobotDeviceNotConnectedError(f"KochRobot is not connected. You need to run `robot.connect()`.")
from_idx = 0
to_idx = 0
follower_goal_pos = {}