add: server computes action, robot's daemon constantly reads it

This commit is contained in:
Francesco Capuano
2025-04-14 19:25:44 +02:00
parent fc107a2c6e
commit a9031ee1be
2 changed files with 57 additions and 85 deletions

View File

@@ -20,9 +20,12 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
self.policy = policy
# TODO: Add device specification for policy inference
self.observation = None
self.clients = []
# self.observation = None
self.observation = async_inference_pb2.Observation(
transfer_state=2,
data=np.array([1], dtype=np.float32).tobytes()
)
self.lock = threading.Lock()
# keeping a list of all observations received from the robot client
self.observations = []
@@ -43,12 +46,15 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
f"data size={len(observation.data)} bytes"
)
with self.lock:
self.observation = observation
self.observations.append(observation)
data = np.frombuffer(self.observation.data, dtype=np.float32)
data = np.frombuffer(
self.observation.data,
# observation data are stored as float32
dtype=np.float32
)
print(f"Current observation data: {data}")
return async_inference_pb2.Empty()
@@ -58,18 +64,8 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
client_id = context.peer()
print(f"Client {client_id} connected for action streaming")
# Keep track of this client for sending actions
with self.lock:
self.clients.append(context)
try:
# Keep the connection alive
while context.is_active():
time.sleep(0.1)
finally:
with self.lock:
if context in self.clients:
self.clients.remove(context)
yield self._generate_and_queue_action(self.observation)
return async_inference_pb2.Empty()
@@ -86,30 +82,22 @@ class PolicyServer(async_inference_pb2_grpc.AsyncInferenceServicer):
def _generate_and_queue_action(self, observation):
"""Generate an action based on the observation (dummy logic).
Mainly used for testing purposes"""
# Just create a random action as a response
action_data = np.random.rand(50).astype(np.float32).tobytes()
# Debinarize the observation data
data = np.frombuffer(
observation.data,
dtype=np.float32
)
# dummy transform on the observation data
action = (data * 1.4).sum()
# map action to bytes
action_data = np.array([action], dtype=np.float32).tobytes()
action = async_inference_pb2.Action(
transfer_state=observation.transfer_state,
data=action_data
)
# Send this action to all connected clients
dead_clients = []
for client_context in self.clients:
try:
if client_context.is_active():
client_context.send_initial_metadata([])
yield action
else:
dead_clients.append(client_context)
except:
dead_clients.append(client_context)
# Clean up dead clients, if any
for dead in dead_clients:
if dead in self.clients:
self.clients.remove(dead)
return action
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))