# pylint: disable=redefined-outer-name
from os.path import join
from secrets import token_urlsafe
from random import randint
from tempfile import TemporaryDirectory
from contextlib import suppress
from time import monotonic

import pytest
from tornado.ioloop import IOLoop

from tfw.internals.networking import ZMQListener, ZMQConnector, Scope, Intent


@pytest.fixture
def _listener_and_connector():
    """Yield a (ZMQListener, ZMQConnector) pair sharing two IPC endpoints.

    The listener gets (downlink, uplink) and the connector the same two
    endpoints in the opposite order. The socket files live in a temporary
    directory that is removed after both endpoints have been closed.
    """
    with TemporaryDirectory() as tmpdir:
        down_sock = join(tmpdir, 'down')
        up_sock = join(tmpdir, 'up')
        server_downlink = f'ipc://{down_sock}'
        server_uplink = f'ipc://{up_sock}'

        listener = ZMQListener(server_downlink, server_uplink)
        try:
            connector = ZMQConnector(server_uplink, server_downlink)
        except BaseException:
            # Don't leak the listener if the connector cannot be created
            listener.close()
            raise

        try:
            yield listener, connector
        finally:
            listener.close()
            connector.close()


@pytest.fixture
def zmq_listener(_listener_and_connector):
    """Yield the listener half of ``_listener_and_connector``."""
    listener, _ = _listener_and_connector
    yield listener


@pytest.fixture
def zmq_connector(_listener_and_connector):
    """Yield the connector half of ``_listener_and_connector``."""
    _, connector = _listener_and_connector
    yield connector


def run_ioloop_once():
    """Run the current tornado IOLoop for roughly 0.1 seconds, then stop it."""
    # Hack: we have to wait for the messages to get through
    # the network stack of the OS while the IOLoop is waiting
    # for them via select/epoll/kqueue.
    # This is an inherent race condition, but solving this
    # problem properly would make the test code difficult
    # to understand, so we use this half measure.
    IOLoop.current().call_later(0.1, IOLoop.current().stop)
    IOLoop.current().start()


@pytest.fixture
def test_messages():
    """Yield 8-16 random messages, each with a random 'key' plus nested
    dict/list/str/int payload fields under random names."""
    def random_str():
        return token_urlsafe(randint(4, 8))

    yield [
        {
            'key': random_str(),
            random_str(): randint(8192, 16384),
            random_str(): random_str(),
            random_str(): {
                random_str(): random_str(),
                random_str(): {random_str(): random_str()}
            },
            random_str(): [random_str(), random_str()]
        }
        for _ in range(randint(8, 16))
    ]


def _recv_matches(connector, expected):
    """Try one non-blocking receive on *connector*; return True iff it
    yields *expected*.

    An IOError from ``recv_message(block=False)`` (presumably raised when
    nothing is available yet) is treated as "no match".
    """
    with suppress(IOError):
        return connector.recv_message(block=False) == expected
    return False


def _fail_if_expired(deadline, what):
    """Fail the running test if the ``time.monotonic`` *deadline* has passed."""
    if monotonic() > deadline:
        pytest.fail(f'Timed out waiting for {what} to reach the subscriber')


def wait_until_subscriber_connects(listener, connector, timeout=10.0):
    """Block until messages published by *listener* reach *connector*.

    Temporarily subscribes *connector* to the private keys '-' and '_',
    performs the handshake described below, then unsubscribes again.

    :param timeout: seconds to wait before failing the test instead of
        spinning forever if the connection never comes up.
    """
    deadline = monotonic() + timeout
    connector.subscribe('-', '_')

    # ZMQ PUB-SUB "slow joiner" problem: messages published before the
    # subscription has propagated are silently dropped, so keep publishing
    # a dummy until one actually gets through the connection.
    dummy = {'key': '-'}
    while True:
        listener.send_message(dummy)
        if _recv_matches(connector, dummy):
            break
        _fail_if_expired(deadline, 'the dummy message')

    # The loop above may leave extra dummies queued on the connector.
    # Publish a sentinel and drain until it arrives so that none of those
    # leftovers leak into the test that follows.
    sentinel = {'key': '_'}
    listener.send_message(sentinel)
    while not _recv_matches(connector, sentinel):
        _fail_if_expired(deadline, 'the sentinel message')

    connector.unsubscribe('-', '_')


def test_server_downlink(zmq_listener, zmq_connector, test_messages):
    messages = []
    zmq_listener.register_callback(messages.append)

    for message in test_messages:
        zmq_connector.send_message(message)

    run_ioloop_once()
    assert messages == test_messages


def test_connector_default_scope_is_zmq(zmq_listener, zmq_connector):
    messages = []
    zmq_listener.register_callback(messages.append)

    zmq_connector.send_message({'key': 'cica'})

    run_ioloop_once()
    assert messages[0]['scope'] == Scope.ZMQ.value


def test_connector_preserves_scope(zmq_listener, zmq_connector):
    messages = []
    zmq_listener.register_callback(messages.append)

    zmq_connector.send_message({'key': 'cica', 'scope': Scope.WEBSOCKET.value})

    run_ioloop_once()
    assert messages[0]['scope'] == Scope.WEBSOCKET.value


def test_connector_scope_overrides_message_scope(zmq_listener, zmq_connector):
    messages = []
    zmq_listener.register_callback(messages.append)

    zmq_connector.send_message(
        {'key': 'cica', 'scope': Scope.WEBSOCKET.value},
        scope=Scope.ZMQ
    )

    run_ioloop_once()
    assert messages[0]['scope'] == Scope.ZMQ.value


def test_connector_adds_intent(zmq_listener, zmq_connector):
    messages = []
    zmq_listener.register_callback(messages.append)

    zmq_connector.send_message(
        {'key': 'cica'},
        intent=Intent.EVENT
    )

    run_ioloop_once()
    assert messages[0]['intent'] == Intent.EVENT.value


def test_connector_preserves_intent(zmq_listener, zmq_connector):
    messages = []
    zmq_listener.register_callback(messages.append)

    zmq_connector.send_message({'key': 'cica', 'intent': Intent.EVENT.value})

    run_ioloop_once()
    assert messages[0]['intent'] == Intent.EVENT.value


def test_server_uplink(zmq_listener, zmq_connector, test_messages):
    zmq_connector.subscribe('')
    wait_until_subscriber_connects(zmq_listener, zmq_connector)

    messages = []
    zmq_connector.register_callback(messages.append)

    for message in test_messages:
        zmq_listener.send_message(message)

    run_ioloop_once()
    assert messages == test_messages


def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
    key1_messages = [{'key': '1', 'data': i} for i in range(randint(128, 256))]
    key2_messages = [{'key': '2', 'data': i} for i in range(randint(128, 256))]
    all_messages = key1_messages + key2_messages

    zmq_connector.subscribe('1')
    wait_until_subscriber_connects(zmq_listener, zmq_connector)

    messages = []
    zmq_connector.register_callback(messages.append)

    for message in all_messages:
        zmq_listener.send_message(message)

    run_ioloop_once()
    assert messages == key1_messages
    assert all(msg not in messages for msg in key2_messages)


def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
    for message in test_messages:
        zmq_connector.send_message(message)
        assert zmq_listener.recv_message() == message


def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
    zmq_connector.subscribe('')
    wait_until_subscriber_connects(zmq_listener, zmq_connector)

    for message in test_messages:
        zmq_listener.send_message(message)
        assert zmq_connector.recv_message() == message


def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
    zmq_listener.register_callback(lambda msg: None)
    zmq_connector.register_callback(lambda msg: None)

    with pytest.raises(RuntimeError):
        zmq_listener.recv_message()
    with pytest.raises(RuntimeError):
        zmq_connector.recv_message()