Updated the image_keys and state_keys mechanisms, and added a per-episode seed to the dataset so the simulation state can be restored on replay
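In short: record() now seeds env.reset() and stores that seed with every frame, so replay() can restore the exact initial simulation state before stepping the recorded actions. A minimal, runnable sketch of that round trip, assuming a Gymnasium-style environment (Pendulum-v1 is only a stand-in for the sim environment used in this repo):

import gymnasium as gym
import numpy as np

env = gym.make("Pendulum-v1")  # stand-in for the sim env

# Record: draw a seed, reset with it, and keep it alongside the episode data.
seed = int(np.random.randint(0, int(1e5)))
observation, info = env.reset(seed=seed)
episode = {"seed": seed, "actions": []}
for _ in range(10):
    action = env.action_space.sample()
    episode["actions"].append(action)
    observation, reward, terminated, truncated, info = env.step(action)

# Replay: resetting with the stored seed restores the same initial state,
# so feeding the recorded actions back retraces the trajectory.
observation, info = env.reset(seed=episode["seed"])
for action in episode["actions"]:
    observation, reward, terminated, truncated, info = env.step(action)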
@@ -242,6 +242,7 @@ def create_rl_hf_dataset(data_dict):
         )

     features["reward"] = Value(dtype="float32", id=None)
+    features["seed"] = Value(dtype="int64", id=None)
     features["episode_index"] = Value(dtype="int64", id=None)
     features["frame_index"] = Value(dtype="int64", id=None)
     features["timestamp"] = Value(dtype="float32", id=None)
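For reference, a cut-down sketch of the resulting Hugging Face datasets schema with the new per-frame "seed" column; the image and state columns of the real create_rl_hf_dataset are omitted, and the sample values are made up:

from datasets import Dataset, Features, Value

features = Features({
    "reward": Value(dtype="float32", id=None),
    "seed": Value(dtype="int64", id=None),
    "episode_index": Value(dtype="int64", id=None),
    "frame_index": Value(dtype="int64", id=None),
    "timestamp": Value(dtype="float32", id=None),
})
ds = Dataset.from_dict(
    {"reward": [0.0], "seed": [42], "episode_index": [0], "frame_index": [0], "timestamp": [0.0]},
    features=features,
)
print(ds.features["seed"])  # Value(dtype='int64', id=None)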
@@ -360,24 +361,29 @@ def record(
     show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
     show_images.start()

+    state_keys_dict = env_cfg.state_keys
+    image_keys = env_cfg.image_keys
     with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writers) as executor:
         # Start recording all episodes
         # start reading from leader, disable stop flag in leader process
         while episode_index < num_episodes:
             logging.info(f"Recording episode {episode_index}")
             say(f"Recording episode {episode_index}")
-            ep_dict = {'action': [], 'observation.state': [], 'reward': []}
+            ep_dict = {'action': [], 'reward': []}
+            for k in state_keys_dict:
+                ep_dict[k] = []
             frame_index = 0
             timestamp = 0
             start_episode_t = time.perf_counter()
-            observation, info = env.reset()

+            # save seed so we can restore the environment state when we want to replay the trajectories
+            seed = np.random.randint(0, 1e5)
+            observation, info = env.reset(seed=seed)
             #with stop_reading_leader.get_lock():
             #    stop_reading_leader.Value = 0
             read_leader.start()
             while timestamp < episode_time_s:
                 action = command_queue.get()
-                image_keys = [key for key in observation if "image" in key]
-                state_keys = [key for key in observation if "image" not in key]
                 for key in image_keys:
                     str_key = key if key.startswith('observation.images.') else 'observation.images.' + key
                     futures += [
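The hunk above replaces per-step key inference from the observation dict with configuration: state_keys is assumed to map dataset column names to observation-dict keys, and image_keys to list the cameras to record. A hypothetical sketch of that config shape ("agent_pos" and "top" are invented names, not from this diff):

# Illustrative stand-in for env_cfg, which is not shown in this diff.
state_keys_dict = {"observation.state": "agent_pos"}  # hypothetical mapping
image_keys = ["top"]                                  # hypothetical camera name

# Per-episode buffers are then created from the configured columns.
ep_dict = {"action": [], "reward": []}
for k in state_keys_dict:
    ep_dict[k] = []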
@@ -388,16 +394,14 @@ def record(
                 if not is_headless() and visualize_images:
                     observations_queue.put(observation)

-                state_obs = []
-                for key in state_keys:
-                    state_obs.append(torch.from_numpy(observation[key]))
-                ep_dict['observation.state'].append(torch.hstack(state_obs) * 180.0 / np.pi)
+                for key, obs_key in state_keys_dict.items():
+                    ep_dict[key].append(torch.from_numpy(observation[obs_key]))

                 # Advance the sim environment
                 if len(action.shape) == 1:
                     action = np.expand_dims(action, 0)
                 observation, reward, _, _, info = env.step(action)
-                ep_dict['action'].append(torch.from_numpy(action) * 180.0 / np.pi)
+                ep_dict['action'].append(torch.from_numpy(action))
                 ep_dict['reward'].append(torch.tensor(reward))
                 print(reward)
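Note that the rad-to-deg scaling disappears from the per-frame appends here and reappears once on the stacked tensors in the next hunk; both orderings produce the same result, as this small check illustrates:

import numpy as np
import torch

# Scaling each frame and then stacking, or stacking first and scaling once,
# give the same (num_frames, dim) tensor; the latter does the work once.
frames = [torch.from_numpy(np.random.randn(6).astype(np.float32)) for _ in range(3)]

per_frame = torch.vstack([f * 180.0 / np.pi for f in frames])
at_episode_end = torch.vstack(frames) * 180.0 / np.pi
assert torch.allclose(per_frame, at_episode_end)  # shape (3, 6) both ways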
@@ -444,10 +448,12 @@ def record(
                     img_path = imgs_dir / f"frame_{i:06d}.png"
                     ep_dict[key].append({"path": str(img_path)})

-            ep_dict['observation.state'] = torch.vstack(ep_dict['observation.state'])
-            ep_dict['action'] = torch.vstack(ep_dict['action'])
+            for key in state_keys_dict:
+                ep_dict[key] = torch.vstack(ep_dict[key]) * 180.0 / np.pi
+            ep_dict['action'] = torch.vstack(ep_dict['action']) * 180.0 / np.pi
             ep_dict['reward'] = torch.stack(ep_dict['reward'])

+            ep_dict["seed"] = torch.tensor([seed] * num_frames)
             ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames)
             ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
             ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
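A small sketch of how those per-frame bookkeeping columns line up for a hypothetical 5-frame episode; because the seed is repeated on every row, replay can recover it from any frame of the episode:

import torch

num_frames, fps, seed, episode_index = 5, 30, 12345, 0

seed_col = torch.tensor([seed] * num_frames)        # tensor([12345, 12345, ...])
episode_col = torch.tensor([episode_index] * num_frames)
frame_index = torch.arange(0, num_frames, 1)        # tensor([0, 1, 2, 3, 4])
timestamp = torch.arange(0, num_frames, 1) / fps    # 0.0000, 0.0333, 0.0667, ...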
@@ -489,7 +495,8 @@ def record(
                 observations_queue.close()
                 break
             else:
-                print('Waiting for ten seconds before starting the next recording session.....')
+                print('Waiting for two seconds before starting the next recording session.....')
+                busy_wait(2)

     num_episodes = episode_index
@@ -579,10 +586,11 @@ def replay(env, episodes: list, fps: int | None = None, root="data", repo_id="le

     dataset = LeRobotDataset(repo_id, root=root)
     items = dataset.hf_dataset.select_columns("action")
+    seeds = dataset.hf_dataset.select_columns("seed")['seed']
     for episode in episodes:
-        env.reset()
         from_idx = dataset.episode_data_index["from"][episode].item()
         to_idx = dataset.episode_data_index["to"][episode].item()
+        env.reset(seed=seeds[from_idx].item())

         logging.info("Replaying episode")
         say("Replaying episode", blocking=True)
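Putting the replay side together, a hedged sketch of the loop as it stands after this commit; dataset, items, env, and episodes come from the surrounding replay() function, and indexing items per frame is an assumption about the hf_dataset API rather than code shown in this diff:

# The seed is constant within an episode, so the first frame's value is
# enough to restore the initial simulation state before stepping actions.
seeds = dataset.hf_dataset.select_columns("seed")["seed"]
for episode in episodes:
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    observation, info = env.reset(seed=seeds[from_idx].item())
    for idx in range(from_idx, to_idx):
        action = items[idx]["action"]  # recorded action for this frame
        env.step(action)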