The current code has problems
@@ -1,5 +1,18 @@
 export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5

 python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data

 # fd token
 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

+python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1
+
+# act
+python lerobot/lerobot/scripts/train.py --policy.type=act --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_the_scale_with_head_camera/ --dataset.repo_id=maic/move_tube_on_scale_head --job_name=act_with_head --output_dir=outputs/train/act_move_bottle_on_scale_with_head
+
+python lerobot/lerobot/scripts/visualize_dataset_html.py --root /home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera --repo-id xxx
+
+# pi0 ft
+python lerobot/lerobot/scripts/train.py \
+    --policy.path=lerobot/pi0 \
+    --wandb.enable=true \
+    --dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera \
+    --dataset.repo_id=maic/move_a_reagent_bottle_on_a_scale_without_head_camera \
+    --job_name=pi0_without_head \
+    --output_dir=outputs/train/move_a_reagent_bottle_on_a_scale_without_head_camera
Binary file not shown.
Binary file not shown.
@@ -42,7 +42,9 @@ class AgilexRobot(Robot):
         if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0:
             print(f"can not get data from {arm_name} topic")
             return None

+        # timestamp tolerance
+        tolerance = 0.1  # allow 100 ms of timestamp skew
         # compute the minimum timestamp
         timestamps = [
             q[-1].header.stamp.to_sec()
@@ -58,18 +60,16 @@ class AgilexRobot(Robot):

         min_time = min(timestamps)

-        # check data synchronization
+        # check data synchronization (allow 100 ms of skew)
         for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
-            if queue[-1].header.stamp.to_sec() < min_time:
+            if queue[-1].header.stamp.to_sec() < min_time - tolerance:
                 return None

         if self.use_depth_image:
             for queue in self.sync_depth_queues.values():
-                if queue[-1].header.stamp.to_sec() < min_time:
+                if queue[-1].header.stamp.to_sec() < min_time - tolerance:
                     return None

         if self.use_robot_base and len(self.sync_base_queue) > 0:
-            if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
+            if self.sync_base_queue[-1].header.stamp.to_sec() < min_time - tolerance:
                 return None

         # extract synchronized data
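Distilled, the tolerance check this hunk introduces: no stream's newest message may lag the freshest stream by more than `tolerance` seconds. A standalone sketch, with plain float stamps standing in for `header.stamp.to_sec()` (helper names hypothetical):

```python
from collections import deque

def newest_stamps_in_sync(queues, tolerance=0.1):
    """Accept a frame only when the newest message of every stream is
    within `tolerance` seconds of the newest message of the slowest one."""
    if any(len(q) == 0 for q in queues):
        return None                      # some stream has produced nothing yet
    latest = [q[-1] for q in queues]     # newest stamp per stream
    min_time = min(latest)               # the slowest stream
    if max(latest) - min_time > tolerance:
        return None                      # streams drifted too far apart
    return min_time

cam = deque([100.00, 100.03, 100.07])
arm = deque([100.01, 100.05, 100.08])
print(newest_stamps_in_sync([cam, arm]))  # 100.07 -> within 100 ms
```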
@@ -81,33 +81,35 @@ class AgilexRobot(Robot):

         # image data
         for cam_name, queue in self.sync_img_queues.items():
-            while queue[0].header.stamp.to_sec() < min_time:
+            while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
                 queue.popleft()
-            frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
+            if queue:
+                frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')

         # depth data
         if self.use_depth_image:
             frame_data['depths'] = {}
             for cam_name, queue in self.sync_depth_queues.items():
-                while queue[0].header.stamp.to_sec() < min_time:
+                while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
                     queue.popleft()
-                depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
-                # keep the original border padding
-                frame_data['depths'][cam_name] = cv2.copyMakeBorder(
-                    depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
-                )
+                if queue:
+                    depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
+                    frame_data['depths'][cam_name] = cv2.copyMakeBorder(
+                        depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
+                    )

         # arm data
         for arm_name, queue in self.sync_arm_queues.items():
-            while queue[0].header.stamp.to_sec() < min_time:
+            while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
                 queue.popleft()
-            frame_data['arms'][arm_name] = queue.popleft()
+            if queue:
+                frame_data['arms'][arm_name] = queue.popleft()

         # base data
         if self.use_robot_base and len(self.sync_base_queue) > 0:
-            while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
+            while self.sync_base_queue and self.sync_base_queue[0].header.stamp.to_sec() < min_time - tolerance:
                 self.sync_base_queue.popleft()
-            frame_data['base'] = self.sync_base_queue.popleft()
+            if self.sync_base_queue:
+                frame_data['base'] = self.sync_base_queue.popleft()

         return frame_data
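Each queue in this hunk follows the same pattern: drop frames older than the reference time, then pop one message only if the deque survived. A sketch with plain float stamps; the `if queue:` guard is what prevents the `IndexError` the pre-fix code could hit on a drained deque:

```python
from collections import deque

def pop_synced(queue, min_time, tolerance=0.1):
    """Drop entries older than min_time - tolerance, then return the next
    entry, or None if the queue was drained (instead of raising IndexError)."""
    while queue and queue[0] < min_time - tolerance:
        queue.popleft()
    return queue.popleft() if queue else None

q = deque([99.5, 99.8, 100.02, 100.06])
print(pop_synced(q, min_time=100.0))        # 100.02 (99.5 and 99.8 are dropped)
print(pop_synced(deque(), min_time=100.0))  # None, no exception
```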
@@ -210,11 +212,11 @@ class AgilexRobot(Robot):
         if arm_states:
             obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1))  # convert to a Python list first

-        if arm_velocity:
-            obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
+        # if arm_velocity:
+        #     obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))

-        if arm_effort:
-            obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
+        # if arm_effort:
+        #     obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))

         if actions:
             action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
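In miniature, what the surviving `observation.state` line computes (a sketch with two 2-element arm vectors):

```python
import numpy as np
import torch

arm_states = [np.array([0.1, 0.2]), np.array([0.3, 0.4])]  # one entry per arm
state = torch.tensor(np.concatenate(arm_states).reshape(-1))
print(state.shape)  # torch.Size([4]) -- both arms flattened into one vector
```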
@@ -276,7 +278,7 @@ class AgilexRobot(Robot):

             # Log timing information
             # self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
-            print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
+            # print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)

         # Combine all arm states into single tensor
         if arm_states:
@@ -299,7 +301,7 @@ class AgilexRobot(Robot):

             # Log timing information
             # self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
-            print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
+            # print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)

         # Process depth data if enabled
         if self.use_depth_image and 'depths' in frame_data:
@@ -311,7 +313,7 @@ class AgilexRobot(Robot):
                 obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor

                 # self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
-                print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
+                # print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)

         # Process base velocity if enabled
         if self.use_robot_base and 'base' in frame_data:
@@ -341,8 +343,8 @@ class AgilexRobot(Robot):
                          -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                          -0.03296661376953125, -0.03296661376953125]

-        last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945,
-                       3.6527481079101562, -0.013187408447265625, -0.013187408447265625,
+        last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
+                       0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
                        0.0, -0.010990142822265625, -0.010990142822265625,
                        -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                        -0.03296661376953125, -0.03296661376953125]
@@ -369,22 +371,23 @@ class AgilexRobot(Robot):
             arm_velocity = last_velocity[from_idx:to_idx]
             arm_effort = last_effort[from_idx:to_idx]
             from_idx = to_idx

+            # fix
+            arm_action[-1] = max(arm_action[-1]*15, 0)

             # Apply safety checks if configured
-            # if 'max_relative_target' in self.config:
-            #     # Get current position from the queue
-            #     if len(self.sync_arm_queues[arm_name]) > 0:
-            #         current_state = self.sync_arm_queues[arm_name][-1]
-            #         current_pos = np.array(current_state.position)
-
-            #         # Clip the action to stay within max relative target
-            #         max_delta = self.config['max_relative_target']
-            #         clipped_action = np.clip(arm_action,
-            #                                  current_pos - max_delta,
-            #                                  current_pos + max_delta)
-            #         arm_action = clipped_action
+            # # Get current position from the queue
+            # if len(arm_action) > 0:
+
+            #     # Clip the action to stay within max relative target
+            #     max_delta = 0.1
+            #     clipped_action = np.clip(arm_action,
+            #                              arm_action - max_delta,
+            #                              arm_action + max_delta)
+            #     arm_action = clipped_action

-            action_sent.append(arm_action)
+            # action_sent.append(arm_action)

             # Create and publish JointState message
             joint_state = JointState()
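Note that the commented-out draft above clips `arm_action` against itself, which is a no-op; a relative-target clamp needs its bounds from the measured position. A sketch, with an illustrative `max_delta`:

```python
import numpy as np

def clip_relative_target(arm_action, current_pos, max_delta=0.1):
    """Limit how far a commanded target may move from the current joint
    position in a single control step."""
    arm_action = np.asarray(arm_action, dtype=float)
    current_pos = np.asarray(current_pos, dtype=float)
    return np.clip(arm_action, current_pos - max_delta, current_pos + max_delta)

print(clip_relative_target([0.5, -0.3], current_pos=[0.1, -0.1]))
# [ 0.2 -0.2] -- each joint moves at most 0.1 toward its target
```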
@@ -323,21 +323,21 @@ class RobotActuators:
             "names": {"motors": state.get('motors', "")}
         }

-        if self.config.get('velocity'):
-            velocity = self.config.get('velocity', "")
-            features["observation.velocity"] = {
-                "dtype": "float32",
-                "shape": (len(velocity.get('motors', "")),),
-                "names": {"motors": velocity.get('motors', "")}
-            }
+        # if self.config.get('velocity'):
+        #     velocity = self.config.get('velocity', "")
+        #     features["observation.velocity"] = {
+        #         "dtype": "float32",
+        #         "shape": (len(velocity.get('motors', "")),),
+        #         "names": {"motors": velocity.get('motors', "")}
+        #     }

-        if self.config.get('effort'):
-            effort = self.config.get('effort', "")
-            features["observation.effort"] = {
-                "dtype": "float32",
-                "shape": (len(effort.get('motors', "")),),
-                "names": {"motors": effort.get('motors', "")}
-            }
+        # if self.config.get('effort'):
+        #     effort = self.config.get('effort', "")
+        #     features["observation.effort"] = {
+        #         "dtype": "float32",
+        #         "shape": (len(effort.get('motors', "")),),
+        #         "names": {"motors": effort.get('motors', "")}
+        #     }

     def _init_action_features(self, features: Dict[str, Any]) -> None:
         """Initialize action features"""
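For orientation, the one feature block that still survives, instantiated with the 14-motor naming this config uses elsewhere (a sketch; the dict layout mirrors the code above):

```python
# The 14 motor names used throughout this config (two 6-DoF arms,
# each with a trailing "none" slot).
motors = [f"{side}_{joint}" for side in ("left", "right")
          for joint in ("joint0", "joint1", "joint2", "joint3",
                        "joint4", "joint5", "none")]

features = {
    "observation.state": {
        "dtype": "float32",
        "shape": (len(motors),),   # (14,)
        "names": {"motors": motors},
    },
    # observation.velocity / observation.effort would follow the same
    # pattern, but are disabled by this commit.
}
print(features["observation.state"]["shape"])  # (14,)
```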
@@ -393,7 +393,7 @@ class RobotDataManager:
         # Check camera image queues
         rospy.loginfo(f"Nums of camera is {len(self.sensors.cameras)}")
         for cam_name in self.sensors.cameras:
-            if len(self.sensors.sync_img_queues[cam_name]) < 50:
+            if len(self.sensors.sync_img_queues[cam_name]) < 200:
                 rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sensors.sync_img_queues[cam_name])}/50)")
                 all_ready = False
                 break
@@ -401,7 +401,7 @@ class RobotDataManager:
         # Check depth queues if enabled
         if self.sensors.use_depth_image:
             for cam_name in self.sensors.sync_depth_queues:
-                if len(self.sensors.sync_depth_queues[cam_name]) < 50:
+                if len(self.sensors.sync_depth_queues[cam_name]) < 200:
                     rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sensors.sync_depth_queues[cam_name])}/50)")
                     all_ready = False
                     break
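The threshold moves to 200 here while the log text still says "/50". One way to keep the two in lockstep (a sketch; `MIN_QUEUE` and `camera_ready` are hypothetical names, not part of this codebase):

```python
MIN_QUEUE = 200  # frames required per stream before recording may start

def camera_ready(name, queue, log=print):
    """Report readiness, logging the threshold from the same constant so
    the message can no longer drift out of sync with the check."""
    if len(queue) < MIN_QUEUE:
        log(f"Waiting for camera {name} (current: {len(queue)}/{MIN_QUEUE})")
        return False
    return True
```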
Binary file not shown.
@@ -140,8 +140,8 @@ def control_loop(

     if num_images > 0:
         # set the display size of each image
-        display_width = 426   # smaller width
-        display_height = 320  # smaller height
+        display_width = 640   # display width
+        display_height = 480  # display height

         # decide the grid layout (as close to a square as possible)
         grid_cols = int(np.ceil(np.sqrt(num_images)))
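The near-square layout this hunk feeds into, spelled out (a sketch; the `grid_rows` derivation is the usual complement of the `grid_cols` line shown and is assumed, not quoted from the file):

```python
import numpy as np

def grid_shape(num_images):
    """Choose a rows x cols grid as close to square as possible."""
    grid_cols = int(np.ceil(np.sqrt(num_images)))
    grid_rows = int(np.ceil(num_images / grid_cols))
    return grid_rows, grid_cols

for n in (1, 3, 4, 7):
    print(n, grid_shape(n))  # 1 -> (1, 1), 3 -> (2, 2), 4 -> (2, 2), 7 -> (3, 3)
```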
@@ -2,17 +2,19 @@ robot_type: aloha_agilex
 ros_node_name: record_episodes
 cameras:
   cam_high:
     img_topic_name: /camera/color/image_raw
     depth_topic_name: /camera/depth/image_rect_raw
     rgb_shape: [480, 640, 3]
     width: 480
     height: 640
-  cam_front:
-    # img_topic_name: /camera/color/image_raw
-    # depth_topic_name: /camera/depth/image_rect_raw
-    img_topic_name: /camera_f/color/image_raw
-    depth_topic_name: /camera_f/depth/image_raw
-    rgb_shape: [480, 640, 3]
-    width: 480
-    height: 640
-    rgb_shape: [480, 640, 3]
+  # cam_front:
+  #   img_topic_name: /camera_f/color/image_raw
+  #   depth_topic_name: /camera_f/depth/image_raw
+  #   width: 480
+  #   height: 640
+  #   rgb_shape: [480, 640, 3]
   cam_left:
     img_topic_name: /camera_l/color/image_raw
     depth_topic_name: /camera_l/depth/image_raw
@@ -92,41 +94,41 @@ state:
     "right_none"
   ]

-velocity:
-  motors: [
-    "left_joint0",
-    "left_joint1",
-    "left_joint2",
-    "left_joint3",
-    "left_joint4",
-    "left_joint5",
-    "left_none",
-    "right_joint0",
-    "right_joint1",
-    "right_joint2",
-    "right_joint3",
-    "right_joint4",
-    "right_joint5",
-    "right_none"
-  ]
+# velocity:
+#   motors: [
+#     "left_joint0",
+#     "left_joint1",
+#     "left_joint2",
+#     "left_joint3",
+#     "left_joint4",
+#     "left_joint5",
+#     "left_none",
+#     "right_joint0",
+#     "right_joint1",
+#     "right_joint2",
+#     "right_joint3",
+#     "right_joint4",
+#     "right_joint5",
+#     "right_none"
+#   ]

-effort:
-  motors: [
-    "left_joint0",
-    "left_joint1",
-    "left_joint2",
-    "left_joint3",
-    "left_joint4",
-    "left_joint5",
-    "left_none",
-    "right_joint0",
-    "right_joint1",
-    "right_joint2",
-    "right_joint3",
-    "right_joint4",
-    "right_joint5",
-    "right_none"
-  ]
+# effort:
+#   motors: [
+#     "left_joint0",
+#     "left_joint1",
+#     "left_joint2",
+#     "left_joint3",
+#     "left_joint4",
+#     "left_joint5",
+#     "left_none",
+#     "right_joint0",
+#     "right_joint1",
+#     "right_joint2",
+#     "right_joint3",
+#     "right_joint4",
+#     "right_joint5",
+#     "right_none"
+#   ]

 action:
   motors: [
@@ -1,36 +1,78 @@
 # coding=utf-8
 from lerobot.common.policies.act.modeling_act import ACTPolicy
 from lerobot.common.robot_devices.utils import busy_wait
 import time
 import argparse
 import rospy
-from common.rosrobot_factory import RobotFactory
-from common.utils.replay_utils import replay
+from common.agilex_robot import AgilexRobot
+import torch
+import cv2
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import cycle


 def get_arguments():
     """
     Parse command line arguments.

     Returns:
         argparse.Namespace: Parsed arguments
     """
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
-    args.repo_id = "tangger/test"
-    args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
-    args.episode = 1  # replay episode
     args.fps = 30
+    args.resume = False
+    args.repo_id = "tangger/test"
+    # args.root = "/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_tube_on_the_scale_without_front"
+    # args.root = "/home/ubuntu/LYT/aloha_lerobot/data4"
+    args.root = "/home/ubuntu/LYT/lerobot_aloha/datasets/abcde"
+    args.num_image_writer_processes = 0
+    args.num_image_writer_threads_per_camera = 4
+    args.video = True
+    args.num_episodes = 50
+    args.episode_time_s = 30000
+    args.play_sounds = False
+    args.display_cameras = True
+    args.single_task = "test test"
+    args.use_depth_image = False
+    args.use_base = False
+    args.push_to_hub = False
+    args.policy = None
+    args.teleoprate = False
     return args


 if __name__ == '__main__':
-    args = get_arguments()
-
-    # Initialize ROS node
-    rospy.init_node("replay_node")
-
-    # Create robot instance using factory pattern
-    robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=args)
-
-    # Replay the specified episode
-    replay(robot, args)
+    cfg = get_arguments()
+    robot = AgilexRobot(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg)
+    inference_time_s = 360
+    fps = 15
+    device = "cuda"  # TODO: On Mac, use "mps" or "cpu"
+
+    dataset = LeRobotDataset(
+        cfg.repo_id,
+        root=cfg.root,
+    )
+    shuffle = True
+    sampler = None
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=1,
+        shuffle=False,
+        sampler=sampler,
+        pin_memory=device != "cpu",
+        drop_last=False,
+    )
+    dl_iter = cycle(dataloader)
+
+    # control the playback speed, fps=30
+    for data in dl_iter:
+        start_time = time.perf_counter()
+        action = data["action"]
+        # cam_high = data["observation.images.cam_high"]
+        # cam_left = data["observation.images.cam_left"]
+        # cam_right = data["observation.images.cam_right"]
+        # print(data)
+
+        # Remove batch dimension
+        action = action.squeeze(0)
+        # Move to cpu, if not already the case
+        action = action.to("cpu")
+        # Order the robot to move
+        robot.send_action(action)
+        print(action)
+        dt_s = time.perf_counter() - start_time
+        busy_wait(1 / fps - dt_s)
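The replay loop paces itself with the standard fixed-rate pattern. A self-contained stand-in (this `busy_wait` is a minimal sketch serving the same role as `lerobot.common.robot_devices.utils.busy_wait`, not its exact implementation):

```python
import time

def busy_wait(seconds):
    """Spin until `seconds` have elapsed; a non-positive wait is a no-op."""
    end = time.perf_counter() + seconds
    while time.perf_counter() < end:
        pass

fps = 15
for _ in range(3):
    start_time = time.perf_counter()
    # ... send one action here ...
    dt_s = time.perf_counter() - start_time
    busy_wait(1 / fps - dt_s)  # pad each iteration out to 1/fps seconds
print("done")
```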
@@ -1,9 +1,12 @@
 from lerobot.common.policies.act.modeling_act import ACTPolicy
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.robot_devices.utils import busy_wait
 import time
 import argparse
 from common.agilex_robot import AgilexRobot
 import torch
 import cv2


 def get_arguments():
     parser = argparse.ArgumentParser()
@@ -34,8 +37,13 @@ inference_time_s = 360
 fps = 30
 device = "cuda"  # TODO: On Mac, use "mps" or "cpu"

-ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
-policy = ACTPolicy.from_pretrained(ckpt_path)
+# ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_move_bottle_on_scale_without_front/checkpoints/last/pretrained_model"
+ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_abcde/checkpoints/last/pretrained_model"
+policy = ACTPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)
+
+# ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/diffusion_abcde/checkpoints/020000/pretrained_model"
+# policy = DiffusionPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)

 policy.to(device)

 for _ in range(inference_time_s * fps):
@@ -46,16 +54,23 @@ for _ in range(inference_time_s * fps):
     if observation is None:
         print("Observation is None, skipping...")
         continue

+    # visualize the image in the observation
+    # cv2.imshow("observation", observation["observation.image"])

     # Convert to pytorch format: channel first and float32 in [0,1]
     # with batch dimension
     for name in observation:
         if "image" in name:
+            img = observation[name].numpy()
+            # cv2.imshow(name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+            # cv2.imwrite(f"{name}.png", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
             observation[name] = observation[name].type(torch.float32) / 255
             observation[name] = observation[name].permute(2, 0, 1).contiguous()
             observation[name] = observation[name].unsqueeze(0)
         observation[name] = observation[name].to(device)

+    last_pos = observation["observation.state"]
     # Compute the next action with the policy
     # based on the current observation
     action = policy.select_action(observation)
@@ -65,6 +80,8 @@ for _ in range(inference_time_s * fps):
     action = action.to("cpu")
     # Order the robot to move
     robot.send_action(action)
+    print("left pos:", action[:7])
+    print("right pos:", action[7:])

     dt_s = time.perf_counter() - start_time
     busy_wait(1 / fps - dt_s)
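The per-image conversion in the loop above, factored into a standalone helper (a sketch; it assumes the HWC uint8 tensors the observation dict holds before conversion):

```python
import torch

def to_policy_image(img_hwc_uint8, device="cpu"):
    """HWC uint8 in [0, 255] -> batched CHW float32 in [0, 1]."""
    t = img_hwc_uint8.type(torch.float32) / 255
    t = t.permute(2, 0, 1).contiguous()  # HWC -> CHW
    return t.unsqueeze(0).to(device)     # add the batch dimension

img = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
print(to_policy_image(img).shape)  # torch.Size([1, 3, 480, 640])
```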