From b2e5f7fe2de5f2e997ddf7a9ffaf10a1d3dd9da8 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 18 Sep 2024 01:23:48 +0200 Subject: [PATCH] Use deque + Nearest neighbor --- test.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/test.py b/test.py index 6e11601f..c604b9fe 100644 --- a/test.py +++ b/test.py @@ -1,28 +1,33 @@ -import math import threading import time +from collections import deque from threading import Thread +import numpy as np + class TemporalQueue: def __init__(self): - self.items = [] - self.timestamps = [] + self.items = deque(maxlen=10) + self.timestamps = deque(maxlen=10) def add(self, item, timestamp): self.items.append(item) self.timestamps.append(timestamp) - def get(self, timestamp=None): - if timestamp is None: - return self.items[-1], self.timestamps[-1] + def get_latest(self): + return self.items[-1], self.timestamps[-1] - # TODO(rcadene): implement nearest neighbor instead of hacky floor - for idx, t in list(enumerate(self.timestamps))[::-1]: - if math.floor(t) == math.floor(timestamp): - return self.items[idx], t + def get(self, timestamp): + timestamps = np.array(list(self.timestamps)) + distances = np.abs(timestamps - timestamp) + nearest_idx = distances.argmin() - raise ValueError() + # print(float(distances[nearest_idx])) + if float(distances[nearest_idx]) > 1 / 5: + raise ValueError() + + return self.items[nearest_idx], self.timestamps[nearest_idx] def __len__(self): return len(self.items) @@ -46,7 +51,7 @@ class Policy: def inference_loop(self): prev_timestamp = None while not self.stop_event.is_set(): - last_observation, last_timestamp = self.obs_queue.get() + last_observation, last_timestamp = self.obs_queue.get_latest() if prev_timestamp is not None and prev_timestamp == last_timestamp: # in case inference ran faster than recording/adding a new observation in the queue