Use deque + Nearest neighbor
This commit is contained in:
29
test.py
29
test.py
@@ -1,28 +1,33 @@
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TemporalQueue:
|
||||
def __init__(self):
|
||||
self.items = []
|
||||
self.timestamps = []
|
||||
self.items = deque(maxlen=10)
|
||||
self.timestamps = deque(maxlen=10)
|
||||
|
||||
def add(self, item, timestamp):
|
||||
self.items.append(item)
|
||||
self.timestamps.append(timestamp)
|
||||
|
||||
def get(self, timestamp=None):
|
||||
if timestamp is None:
|
||||
return self.items[-1], self.timestamps[-1]
|
||||
def get_latest(self):
|
||||
return self.items[-1], self.timestamps[-1]
|
||||
|
||||
# TODO(rcadene): implement nearest neighbor instead of hacky floor
|
||||
for idx, t in list(enumerate(self.timestamps))[::-1]:
|
||||
if math.floor(t) == math.floor(timestamp):
|
||||
return self.items[idx], t
|
||||
def get(self, timestamp):
|
||||
timestamps = np.array(list(self.timestamps))
|
||||
distances = np.abs(timestamps - timestamp)
|
||||
nearest_idx = distances.argmin()
|
||||
|
||||
raise ValueError()
|
||||
# print(float(distances[nearest_idx]))
|
||||
if float(distances[nearest_idx]) > 1 / 5:
|
||||
raise ValueError()
|
||||
|
||||
return self.items[nearest_idx], self.timestamps[nearest_idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.items)
|
||||
@@ -46,7 +51,7 @@ class Policy:
|
||||
def inference_loop(self):
|
||||
prev_timestamp = None
|
||||
while not self.stop_event.is_set():
|
||||
last_observation, last_timestamp = self.obs_queue.get()
|
||||
last_observation, last_timestamp = self.obs_queue.get_latest()
|
||||
|
||||
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
||||
# in case inference ran faster than recording/adding a new observation in the queue
|
||||
|
||||
Reference in New Issue
Block a user