restructure code
This commit is contained in:
105
lerobot_aloha/common/utils/data_utils.py
Normal file
105
lerobot_aloha/common/utils/data_utils.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
import time
|
||||
from pprint import pprint
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import log_say, has_method
|
||||
from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode
|
||||
|
||||
|
||||
def record(
|
||||
robot,
|
||||
cfg
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Record robot data according to configuration.
|
||||
|
||||
Args:
|
||||
robot: Robot instance
|
||||
cfg: Configuration object
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: Dataset with recorded episodes
|
||||
"""
|
||||
# Initialize or load dataset
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
cfg.repo_id,
|
||||
root=cfg.root,
|
||||
)
|
||||
if len(robot.cameras) > 0:
|
||||
dataset.start_image_writer(
|
||||
num_processes=cfg.num_image_writer_processes,
|
||||
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.repo_id,
|
||||
cfg.fps,
|
||||
root=cfg.root,
|
||||
robot=None,
|
||||
features=robot.features,
|
||||
use_videos=cfg.video,
|
||||
image_writer_processes=cfg.num_image_writer_processes,
|
||||
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
|
||||
# Initialize keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Print recording instructions
|
||||
print()
|
||||
print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")
|
||||
|
||||
# Record episodes
|
||||
recorded_episodes = 0
|
||||
while True:
|
||||
if recorded_episodes >= cfg.num_episodes:
|
||||
break
|
||||
|
||||
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
|
||||
pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
|
||||
record_episode(
|
||||
robot=robot,
|
||||
dataset=dataset,
|
||||
events=events,
|
||||
episode_time_s=cfg.episode_time_s,
|
||||
display_cameras=cfg.display_cameras,
|
||||
policy=policy,
|
||||
fps=cfg.fps,
|
||||
single_task=cfg.single_task,
|
||||
)
|
||||
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", cfg.play_sounds)
|
||||
pprint("Reset the environment, stop recording")
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", cfg.play_sounds)
|
||||
pprint("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
||||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
|
||||
log_say("Exiting", cfg.play_sounds)
|
||||
return dataset
|
||||
Reference in New Issue
Block a user