Use deque + Nearest neighbor
This commit is contained in:
29
test.py
29
test.py
@@ -1,28 +1,33 @@
|
|||||||
import math
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from collections import deque
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
class TemporalQueue:
|
class TemporalQueue:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.items = []
|
self.items = deque(maxlen=10)
|
||||||
self.timestamps = []
|
self.timestamps = deque(maxlen=10)
|
||||||
|
|
||||||
def add(self, item, timestamp):
|
def add(self, item, timestamp):
|
||||||
self.items.append(item)
|
self.items.append(item)
|
||||||
self.timestamps.append(timestamp)
|
self.timestamps.append(timestamp)
|
||||||
|
|
||||||
def get(self, timestamp=None):
|
def get_latest(self):
|
||||||
if timestamp is None:
|
return self.items[-1], self.timestamps[-1]
|
||||||
return self.items[-1], self.timestamps[-1]
|
|
||||||
|
|
||||||
# TODO(rcadene): implement nearest neighbor instead of hacky floor
|
def get(self, timestamp):
|
||||||
for idx, t in list(enumerate(self.timestamps))[::-1]:
|
timestamps = np.array(list(self.timestamps))
|
||||||
if math.floor(t) == math.floor(timestamp):
|
distances = np.abs(timestamps - timestamp)
|
||||||
return self.items[idx], t
|
nearest_idx = distances.argmin()
|
||||||
|
|
||||||
raise ValueError()
|
# print(float(distances[nearest_idx]))
|
||||||
|
if float(distances[nearest_idx]) > 1 / 5:
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
return self.items[nearest_idx], self.timestamps[nearest_idx]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.items)
|
return len(self.items)
|
||||||
@@ -46,7 +51,7 @@ class Policy:
|
|||||||
def inference_loop(self):
|
def inference_loop(self):
|
||||||
prev_timestamp = None
|
prev_timestamp = None
|
||||||
while not self.stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
last_observation, last_timestamp = self.obs_queue.get()
|
last_observation, last_timestamp = self.obs_queue.get_latest()
|
||||||
|
|
||||||
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
if prev_timestamp is not None and prev_timestamp == last_timestamp:
|
||||||
# in case inference ran faster than recording/adding a new observation in the queue
|
# in case inference ran faster than recording/adding a new observation in the queue
|
||||||
|
|||||||
Reference in New Issue
Block a user