diff --git a/test.py b/test.py index 155b10c87..6e11601fd 100644 --- a/test.py +++ b/test.py @@ -34,23 +34,31 @@ class Policy: self.action_queue = TemporalQueue() self.thread = None + self.n_action = 2 + FPS = 10 # noqa: N806 + self.delta_timestamps = [i / FPS for i in range(self.n_action)] + def inference(self, observation): # TODO time.sleep(0.5) - return observation + return [observation] * self.n_action def inference_loop(self): - previous_timestamp = None + prev_timestamp = None while not self.stop_event.is_set(): - latest_observation, latest_timestamp = self.obs_queue.get() + last_observation, last_timestamp = self.obs_queue.get() - if previous_timestamp is not None and previous_timestamp == latest_timestamp: + if prev_timestamp is not None and prev_timestamp == last_timestamp: # in case inference ran faster than recording/adding a new observation in the queue time.sleep(0.1) - else: - predicted_action_sequence = self.inference(latest_observation) - self.action_queue.add(predicted_action_sequence, latest_timestamp) - previous_timestamp = latest_timestamp + continue + + pred_action_sequence = self.inference(last_observation) + + for action, delta_ts in zip(pred_action_sequence, self.delta_timestamps, strict=False): + self.action_queue.add(action, last_timestamp + delta_ts) + + prev_timestamp = last_timestamp def select_action( self,