diff --git a/tfw/internals/networking/test_networking.py b/tfw/internals/networking/test_networking.py index 9bd35af..d1d4261 100644 --- a/tfw/internals/networking/test_networking.py +++ b/tfw/internals/networking/test_networking.py @@ -3,6 +3,7 @@ from os.path import join from secrets import token_urlsafe from random import randint from tempfile import TemporaryDirectory +from contextlib import suppress import pytest from tornado.ioloop import IOLoop @@ -63,6 +64,26 @@ def test_messages(): ] +def wait_until_subscriber_connects(listener, connector): + # Warning: you are better off without comprehending how this works + # Reference: ZMQ PUB-SUB slow joiner problem + + # Wait until something can go through the connection + dummy = {'key': '-'} + while True: + listener.send_message(dummy) + with suppress(IOError): + if connector.recv_message(block=False) == dummy: + break + # Throw away leftover messages from last while loop + sentinel = {'key': '_'} + listener.send_message(sentinel) + while True: + with suppress(IOError): + if connector.recv_message(block=False) == sentinel: + break + + def test_server_downlink(zmq_listener, zmq_connector, test_messages): messages = [] zmq_listener.register_callback(messages.append) @@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector): assert messages == key1_messages assert all((msg not in messages for msg in key2_messages)) + + +def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages): + for message in test_messages: + zmq_connector.send_message(message) + assert zmq_listener.recv_message() == message + + +def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages): + zmq_connector.subscribe('') + wait_until_subscriber_connects(zmq_listener, zmq_connector) + for message in test_messages: + zmq_listener.send_message(message) + assert zmq_connector.recv_message() == message + + +def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector): + zmq_listener.register_callback(lambda msg: None) + zmq_connector.register_callback(lambda msg: None) + + with pytest.raises(RuntimeError): + zmq_listener.recv_message() + + with pytest.raises(RuntimeError): + zmq_connector.recv_message() diff --git a/tfw/internals/networking/zmq_connector.py b/tfw/internals/networking/zmq_connector.py index a701c99..7dfa77c 100644 --- a/tfw/internals/networking/zmq_connector.py +++ b/tfw/internals/networking/zmq_connector.py @@ -4,7 +4,11 @@ import zmq from zmq.eventloop.zmqstream import ZMQStream from .scope import Scope -from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg +from .serialization import ( + serialize_tfw_msg, + deserialize_tfw_msg, + with_deserialize_tfw_msg +) LOG = logging.getLogger(__name__) @@ -29,8 +33,20 @@ class ZMQDownlinkConnector: self.keys.remove(key) def register_callback(self, callback): - self._on_recv_callback = callback - self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv)) + if callback: + self._on_recv_callback = callback + self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv)) + else: + self._zmq_sub_stream.on_recv(None) + + def recv_message(self, *, block=True): + if self._zmq_sub_stream.receiving(): + raise RuntimeError('Synchronous recv() called while a callback is registered!') + flags = 0 if block else zmq.NOBLOCK + try: + return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags)) + except zmq.ZMQError: + raise IOError("No data available to recv!") def _on_recv(self, message): key = message['key'] diff --git a/tfw/internals/networking/zmq_listener.py b/tfw/internals/networking/zmq_listener.py index a1e1313..d43814a 100644 --- a/tfw/internals/networking/zmq_listener.py +++ b/tfw/internals/networking/zmq_listener.py @@ -3,7 +3,11 @@ import logging import zmq from zmq.eventloop.zmqstream import ZMQStream -from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg +from .serialization import ( + serialize_tfw_msg, + deserialize_tfw_msg, + with_deserialize_tfw_msg +) LOG = logging.getLogger(__name__) @@ -17,9 +21,18 @@ class ZMQDownlinkListener: LOG.debug('Pull socket bound to %s', bind_addr) def register_callback(self, callback): - callback = with_deserialize_tfw_msg(callback) + callback = with_deserialize_tfw_msg(callback) if callback else None self._zmq_pull_stream.on_recv(callback) + def recv_message(self, *, block=True): + if self._zmq_pull_stream.receiving(): + raise RuntimeError('Synchronous recv() called while a callback is registered!') + flags = 0 if block else zmq.NOBLOCK + try: + return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags)) + except zmq.ZMQError: + raise IOError("No data available to recv!") + def close(self): self._zmq_pull_stream.close()