added visualize images option

This commit is contained in:
Michel Aractingi
2024-10-17 10:53:38 +02:00
parent b45490874a
commit 1adcb3fdec

View File

@@ -289,7 +289,7 @@ def record(
tags=None, tags=None,
num_image_writers_per_camera=4, num_image_writers_per_camera=4,
force_override=False, force_override=False,
visualization_mode='viewer', visualize_images=0,
**kwargs **kwargs
): ):
@@ -323,7 +323,6 @@ def record(
exit_early = False exit_early = False
rerecord_episode = False rerecord_episode = False
stop_recording = False stop_recording = False
# Only import pynput if not in a headless environment # Only import pynput if not in a headless environment
if not is_headless(): if not is_headless():
from pynput import keyboard from pynput import keyboard
@@ -363,7 +362,7 @@ def record(
command_queue = multiprocessing.Queue(1000) command_queue = multiprocessing.Queue(1000)
stop_reading_leader = multiprocessing.Value('i', 0) stop_reading_leader = multiprocessing.Value('i', 0)
read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader)) read_leader = multiprocessing.Process(target=read_commands_from_leader, args=(robot, command_queue, fps, axis_directions, offsets, stop_reading_leader))
if not is_headless() and visualization_mode=='observations': if not is_headless() and visualize_images:
observations_queue = multiprocessing.Queue(1000) observations_queue = multiprocessing.Queue(1000)
show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, )) show_images = multiprocessing.Process(target=show_image_observations, args=(observations_queue, ))
show_images.start() show_images.start()
@@ -393,7 +392,7 @@ def record(
save_image, observation[key].squeeze(0), str_key, frame_index, episode_index, videos_dir) save_image, observation[key].squeeze(0), str_key, frame_index, episode_index, videos_dir)
] ]
if not is_headless() and visualization_mode=='observations': if not is_headless() and visualize_images:
observations_queue.put(observation) observations_queue.put(observation)
state_obs = [] state_obs = []
@@ -494,7 +493,7 @@ def record(
concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images"
): ):
pass pass
if not is_headless() and visualization_mode=='rgb_array': if not is_headless() and visualize_images:
show_images.terminate() show_images.terminate()
observations_queue.close() observations_queue.close()
break break
@@ -692,11 +691,10 @@ if __name__ == "__main__":
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.", help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
) )
parser_record.add_argument( parser_record.add_argument(
"--visualization-mode", "--visualize-images",
type=str, type=int,
default='viewer', default=0,
choices=['viewer', 'observations'], help="Visualize image observations with opencv.",
help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
) )
parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay = subparsers.add_parser("replay", parents=[base_parser])
@@ -731,7 +729,6 @@ if __name__ == "__main__":
# make gym env # make gym env
env_cfg = init_hydra_config(env_config_path) env_cfg = init_hydra_config(env_config_path)
env_cfg.env.gym.render_mode = 'human' if args.visualization_mode=='viewer' else 'rgb_array'
env_fn = lambda: make_env(env_cfg, n_envs=1) env_fn = lambda: make_env(env_cfg, n_envs=1)
# make robot # make robot