cap DROID execution frequency (#282)
This commit is contained in:
@@ -6,7 +6,7 @@ import datetime
|
|||||||
import faulthandler
|
import faulthandler
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
|
import time
|
||||||
from moviepy.editor import ImageSequenceClip
|
from moviepy.editor import ImageSequenceClip
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from openpi_client import image_tools
|
from openpi_client import image_tools
|
||||||
@@ -19,6 +19,9 @@ import tyro
|
|||||||
|
|
||||||
faulthandler.enable()
|
faulthandler.enable()
|
||||||
|
|
||||||
|
# DROID data collection frequency -- we slow down execution to match this frequency
|
||||||
|
DROID_CONTROL_FREQUENCY = 15
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class Args:
|
class Args:
|
||||||
@@ -95,6 +98,7 @@ def main(args: Args):
|
|||||||
bar = tqdm.tqdm(range(args.max_timesteps))
|
bar = tqdm.tqdm(range(args.max_timesteps))
|
||||||
print("Running rollout... press Ctrl+C to stop early.")
|
print("Running rollout... press Ctrl+C to stop early.")
|
||||||
for t_step in bar:
|
for t_step in bar:
|
||||||
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
# Get the current observation
|
# Get the current observation
|
||||||
curr_obs = _extract_observation(
|
curr_obs = _extract_observation(
|
||||||
@@ -145,6 +149,11 @@ def main(args: Args):
|
|||||||
action = np.clip(action, -1, 1)
|
action = np.clip(action, -1, 1)
|
||||||
|
|
||||||
env.step(action)
|
env.step(action)
|
||||||
|
|
||||||
|
# Sleep to match DROID data collection frequency
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
|
||||||
|
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user