updated image_keys and state_keys mechanisms + added seed to the dataset in order to restore the simulation state

This commit is contained in:
Michel Aractingi
2024-10-22 12:00:25 +02:00
parent 8b17416fc7
commit 04029f5e74

View File

@@ -242,6 +242,7 @@ def create_rl_hf_dataset(data_dict):
)
features["reward"] = Value(dtype="float32", id=None)
features["seed"] = Value(dtype="int64", id=None)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
@@ -360,24 +361,29 @@ def record(
show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
show_images.start()
state_keys_dict = env_cfg.state_keys
image_keys = env_cfg.image_keys
with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
# Start recording all episodes
# start reading from leader, disable stop flag in leader process
while episode_index < num_episodes:
logging.info(f"Recording episode {episode_index}")
say(f"Recording episode {episode_index}")
ep_dict = {'action':[], 'observation.state': [], 'reward':[]}
ep_dict = {'action':[], 'reward':[]}
for k in state_keys_dict:
ep_dict[k] = []
frame_index = 0
timestamp = 0
start_episode_t = time.perf_counter()
observation, info = env.reset()
# save seed so we can restore the environment state when we want to replay the trajectories
seed = np.random.randint(0,1e5)
observation, info = env.reset(seed=seed)
#with stop_reading_leader.get_lock():
#stop_reading_leader.Value = 0
read_leader.start()
while timestamp < episode_time_s:
action = command_queue.get()
image_keys = [key for key in observation if "image" in key]
state_keys = [key for key in observation if "image" not in key]
for key in image_keys:
str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
futures += [
@@ -388,16 +394,14 @@ def record(
if not is_headless() and visualize_images:
observations_queue.put(observation)
state_obs = []
for key in state_keys:
state_obs.append(torch.from_numpy(observation[key]))
ep_dict['observation.state'].append(torch.hstack(state_obs) * 180.0 / np.pi)
for key, obs_key in state_keys_dict.items():
ep_dict[key].append(torch.from_numpy(observation[obs_key]))
# Advance the sim environment
if len(action.shape) == 1:
action = np.expand_dims(action, 0)
observation, reward, _, _ , info = env.step(action)
ep_dict['action'].append(torch.from_numpy(action) * 180.0 / np.pi)
ep_dict['action'].append(torch.from_numpy(action))
ep_dict['reward'].append(torch.tensor(reward))
print(reward)
@@ -444,10 +448,12 @@ def record(
img_path = imgs_dir / f"frame_{i:06d}.png"
ep_dict[key].append({"path": str(img_path)})
ep_dict['observation.state'] = torch.vstack(ep_dict['observation.state'])
ep_dict['action'] = torch.vstack(ep_dict['action'])
for key in state_keys_dict:
ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
ep_dict['reward'] = torch.stack(ep_dict['reward'])
ep_dict["seed"] = torch.tensor([seed] * num_frames)
ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
@@ -489,7 +495,8 @@ def record(
observations_queue.close()
break
else:
print('Waiting for ten seconds before starting the next recording session.....')
print('Waiting for two seconds before starting the next recording session.....')
busy_wait(2)
num_episodes = episode_index
@@ -579,10 +586,11 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
seeds = dataset.hf_dataset.select_columns("seed")['seed']
for episode in episodes:
env.reset()
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
env.reset(seed=seeds[from_idx].item())
logging.info("Replaying episode")
say("Replaying episode", blocking=True)