Use deque + Nearest neighbor

This commit is contained in:
Remi Cadene
2024-09-18 01:23:48 +02:00
parent 21de778377
commit b2e5f7fe2d

29
test.py
View File

@@ -1,28 +1,33 @@
import math
import threading
import time
from collections import deque
from threading import Thread
import numpy as np
class TemporalQueue:
def __init__(self):
self.items = []
self.timestamps = []
self.items = deque(maxlen=10)
self.timestamps = deque(maxlen=10)
def add(self, item, timestamp):
self.items.append(item)
self.timestamps.append(timestamp)
def get(self, timestamp=None):
if timestamp is None:
return self.items[-1], self.timestamps[-1]
def get_latest(self):
return self.items[-1], self.timestamps[-1]
# TODO(rcadene): implement nearest neighbor instead of hacky floor
for idx, t in list(enumerate(self.timestamps))[::-1]:
if math.floor(t) == math.floor(timestamp):
return self.items[idx], t
def get(self, timestamp):
timestamps = np.array(list(self.timestamps))
distances = np.abs(timestamps - timestamp)
nearest_idx = distances.argmin()
raise ValueError()
# print(float(distances[nearest_idx]))
if float(distances[nearest_idx]) > 1 / 5:
raise ValueError()
return self.items[nearest_idx], self.timestamps[nearest_idx]
def __len__(self):
return len(self.items)
@@ -46,7 +51,7 @@ class Policy:
def inference_loop(self):
prev_timestamp = None
while not self.stop_event.is_set():
last_observation, last_timestamp = self.obs_queue.get()
last_observation, last_timestamp = self.obs_queue.get_latest()
if prev_timestamp is not None and prev_timestamp == last_timestamp:
# in case inference ran faster than recording/adding a new observation in the queue