Implement sync API for networking

This commit is contained in:
Kristóf Tóth 2019-07-31 17:08:47 +02:00
parent 78c3a8cf98
commit 911831fdb1
3 changed files with 80 additions and 5 deletions

View File

@ -3,6 +3,7 @@ from os.path import join
from secrets import token_urlsafe
from random import randint
from tempfile import TemporaryDirectory
from contextlib import suppress
import pytest
from tornado.ioloop import IOLoop
@ -63,6 +64,26 @@ def test_messages():
]
def wait_until_subscriber_connects(listener, connector):
# Warning: you are better off without comprehending how this works
# Reference: ZMQ PUB-SUB slow joiner problem
# Wait until something can go through the connection
dummy = {'key': '-'}
while True:
listener.send_message(dummy)
with suppress(IOError):
if connector.recv_message(block=False) == dummy:
break
# Throw away leftover messages from last while loop
sentinel = {'key': '_'}
listener.send_message(sentinel)
while True:
with suppress(IOError):
if connector.recv_message(block=False) == sentinel:
break
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
messages = []
zmq_listener.register_callback(messages.append)
@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
assert messages == key1_messages
assert all((msg not in messages for msg in key2_messages))
def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
for message in test_messages:
zmq_connector.send_message(message)
assert zmq_listener.recv_message() == message
def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
zmq_connector.subscribe('')
wait_until_subscriber_connects(zmq_listener, zmq_connector)
for message in test_messages:
zmq_listener.send_message(message)
assert zmq_connector.recv_message() == message
def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
zmq_listener.register_callback(lambda msg: None)
zmq_connector.register_callback(lambda msg: None)
with pytest.raises(RuntimeError):
zmq_listener.recv_message()
with pytest.raises(RuntimeError):
zmq_connector.recv_message()

View File

@ -4,7 +4,11 @@ import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .scope import Scope
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
from .serialization import (
serialize_tfw_msg,
deserialize_tfw_msg,
with_deserialize_tfw_msg
)
LOG = logging.getLogger(__name__)
@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
self.keys.remove(key)
def register_callback(self, callback):
self._on_recv_callback = callback
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
if callback:
self._on_recv_callback = callback
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
else:
self._zmq_sub_stream.on_recv(None)
def recv_message(self, *, block=True):
if self._zmq_sub_stream.receiving():
raise RuntimeError('Synchronous recv() called while a callback is registered!')
flags = 0 if block else zmq.NOBLOCK
try:
return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
except zmq.ZMQError:
raise IOError("No data available to recv!")
def _on_recv(self, message):
key = message['key']

View File

@ -3,7 +3,11 @@ import logging
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
from .serialization import (
serialize_tfw_msg,
deserialize_tfw_msg,
with_deserialize_tfw_msg
)
LOG = logging.getLogger(__name__)
@ -17,9 +21,18 @@ class ZMQDownlinkListener:
LOG.debug('Pull socket bound to %s', bind_addr)
def register_callback(self, callback):
callback = with_deserialize_tfw_msg(callback)
callback = with_deserialize_tfw_msg(callback) if callback else None
self._zmq_pull_stream.on_recv(callback)
def recv_message(self, *, block=True):
if self._zmq_pull_stream.receiving():
raise RuntimeError('Synchronous recv() called while a callback is registered!')
flags = 0 if block else zmq.NOBLOCK
try:
return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
except zmq.ZMQError:
raise IOError("No data available to recv!")
def close(self):
self._zmq_pull_stream.close()