updated image_keys and state_keys mechanisms + added seed to the dataset in order to restore the simulation state
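The gist of the change, as a minimal sketch (illustrative only; `env` stands in for whatever Gymnasium-style simulation environment `record` drives):

    import numpy as np

    seed = np.random.randint(0, int(1e5))     # one fresh seed per episode
    observation, info = env.reset(seed=seed)  # seed the simulator's RNG
    # ... record the episode, storing `seed` alongside every frame ...
    # Replaying later with the stored seed restores the same initial sim state:
    observation, info = env.reset(seed=seed)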
@@ -242,6 +242,7 @@ def create_rl_hf_dataset(data_dict):
     )
     features["reward"] = Value(dtype="float32", id=None)

+    features["seed"] = Value(dtype="int64", id=None)
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
     features["timestamp"] = Value(dtype="float32", id=None)
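These `Value` declarations feed a Hugging Face `datasets` schema. A minimal sketch of how such a features dict is typically consumed, assuming `create_rl_hf_dataset` follows the standard `Dataset.from_dict` pattern (not shown in this diff):

    from datasets import Dataset, Features, Value

    features = {
        "reward": Value(dtype="float32", id=None),
        "seed": Value(dtype="int64", id=None),  # new per-frame seed column
    }
    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))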
@@ -360,24 +361,29 @@ def record(
         show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
         show_images.start()

+    state_keys_dict = env_cfg.state_keys
+    image_keys = env_cfg.image_keys
     with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
         # Start recording all episodes
         # start reading from leader, disable stop flag in leader process
         while episode_index < num_episodes:
             logging.info(f"Recording episode {episode_index}")
             say(f"Recording episode {episode_index}")
-            ep_dict = {'action':[], 'observation.state': [], 'reward':[]}
+            ep_dict = {'action':[], 'reward':[]}
+            for k in state_keys_dict:
+                ep_dict[k] = []
             frame_index = 0
             timestamp = 0
             start_episode_t = time.perf_counter()
-            observation, info = env.reset()
+            # save seed so we can restore the environment state when we want to replay the trajectories
+            seed = np.random.randint(0,1e5)
+            observation, info = env.reset(seed=seed)
             #with stop_reading_leader.get_lock():
             #stop_reading_leader.Value = 0
             read_leader.start()
             while timestamp < episode_time_s:
                 action = command_queue.get()
-                image_keys = [key for key in observation if "image" in key]
-                state_keys = [key for key in observation if "image" not in key]
                 for key in image_keys:
                     str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                     futures += [
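The replay guarantee rests on the Gymnasium reset contract: two resets with the same seed yield the same initial state. A quick self-contained check, using a stock environment as a stand-in for the sim env:

    import gymnasium as gym
    import numpy as np

    env = gym.make("Pendulum-v1")
    obs_a, _ = env.reset(seed=42)
    obs_b, _ = env.reset(seed=42)
    assert np.allclose(obs_a, obs_b)  # identical initial observations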
@@ -388,16 +394,14 @@ def record(
                 if not is_headless() and visualize_images:
                     observations_queue.put(observation)

-                state_obs = []
-                for key in state_keys:
-                    state_obs.append(torch.from_numpy(observation[key]))
-                ep_dict['observation.state'].append(torch.hstack(state_obs) * 180.0 / np.pi)
+                for key, obs_key in state_keys_dict.items():
+                    ep_dict[key].append(torch.from_numpy(observation[obs_key]))

                 # Advance the sim environment
                 if len(action.shape) == 1:
                     action = np.expand_dims(action, 0)
                 observation, reward, _, _ , info = env.step(action)
-                ep_dict['action'].append(torch.from_numpy(action) * 180.0 / np.pi)
+                ep_dict['action'].append(torch.from_numpy(action))
                 ep_dict['reward'].append(torch.tensor(reward))
                 print(reward)
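`state_keys_dict` replaces the old name-sniffing (`"image" not in key`) with an explicit mapping from dataset column names to raw observation keys, configured via `env_cfg.state_keys`. A hypothetical example of what that mapping might look like (the real keys depend on the env config):

    state_keys_dict = {
        "observation.state": "agent_pos",  # dataset column -> env observation key
    }
    for key, obs_key in state_keys_dict.items():
        ep_dict[key].append(torch.from_numpy(observation[obs_key]))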
@@ -444,10 +448,12 @@ def record(
                     img_path = imgs_dir / f"frame_{i:06d}.png"
                     ep_dict[key].append({"path": str(img_path)})

-            ep_dict['observation.state'] = torch.vstack(ep_dict['observation.state'])
-            ep_dict['action'] = torch.vstack(ep_dict['action'])
+            for key in state_keys_dict:
+                ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
+            ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
             ep_dict['reward'] = torch.stack(ep_dict['reward'])

+            ep_dict["seed"] = torch.tensor([seed] * num_frames)
             ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
             ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
             ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
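Note the radians-to-degrees conversion (`* 180.0 / np.pi`) moved from the per-frame appends to a single per-episode pass after stacking. The result is unchanged, since elementwise scaling commutes with stacking:

    # equivalent, for a list of per-frame tensors `frames`:
    torch.vstack([f * 180.0 / np.pi for f in frames])
    torch.vstack(frames) * 180.0 / np.pi

The seed is broadcast to one value per frame so each row of the final table is self-describing.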
@@ -489,7 +495,8 @@ def record(
                 observations_queue.close()
                 break
             else:
-                print('Waiting for ten seconds before starting the next recording session.....')
+                print('Waiting for two seconds before starting the next recording session.....')
+                busy_wait(2)


         num_episodes = episode_index
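`busy_wait(2)` makes the printed two-second pause actually happen, instead of the message alone. If it matches the spin-wait helper used elsewhere in lerobot (an assumption; check the actual import), it looks roughly like:

    import time

    def busy_wait(seconds: float) -> None:
        # spin on perf_counter instead of time.sleep() for more precise timing
        end = time.perf_counter() + seconds
        while time.perf_counter() < end:
            pass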
@@ -579,10 +586,11 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le

     dataset = LeRobotDataset(repo_id, root=root)
     items = dataset.hf_dataset.select_columns("action")
+    seeds = dataset.hf_dataset.select_columns("seed")['seed']
     for episode in episodes:
-        env.reset()
         from_idx = dataset.episode_data_index["from"][episode].item()
         to_idx = dataset.episode_data_index["to"][episode].item()
+        env.reset(seed=seeds[from_idx].item())

         logging.info("Replaying episode")
         say("Replaying episode", blocking=True)
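Because the seed is stored on every frame, an episode's seed can be read from its first frame. A sketch of how the replay loop then pulls the recorded actions for that slice (standard `datasets` slicing; illustrative only):

    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    env.reset(seed=seeds[from_idx].item())       # restore the initial sim state
    actions = items[from_idx:to_idx]["action"]   # actions to step through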