Implement sync API for networking

This commit is contained in:
Kristóf Tóth 2019-07-31 17:08:47 +02:00
parent 78c3a8cf98
commit 911831fdb1
3 changed files with 80 additions and 5 deletions

View File

@ -3,6 +3,7 @@ from os.path import join
from secrets import token_urlsafe from secrets import token_urlsafe
from random import randint from random import randint
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from contextlib import suppress
import pytest import pytest
from tornado.ioloop import IOLoop from tornado.ioloop import IOLoop
@ -63,6 +64,26 @@ def test_messages():
] ]
def wait_until_subscriber_connects(listener, connector):
# Warning: you are better off without comprehending how this works
# Reference: ZMQ PUB-SUB slow joiner problem
# Wait until something can go through the connection
dummy = {'key': '-'}
while True:
listener.send_message(dummy)
with suppress(IOError):
if connector.recv_message(block=False) == dummy:
break
# Throw away leftover messages from last while loop
sentinel = {'key': '_'}
listener.send_message(sentinel)
while True:
with suppress(IOError):
if connector.recv_message(block=False) == sentinel:
break
def test_server_downlink(zmq_listener, zmq_connector, test_messages): def test_server_downlink(zmq_listener, zmq_connector, test_messages):
messages = [] messages = []
zmq_listener.register_callback(messages.append) zmq_listener.register_callback(messages.append)
@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
assert messages == key1_messages assert messages == key1_messages
assert all((msg not in messages for msg in key2_messages)) assert all((msg not in messages for msg in key2_messages))
def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
for message in test_messages:
zmq_connector.send_message(message)
assert zmq_listener.recv_message() == message
def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
zmq_connector.subscribe('')
wait_until_subscriber_connects(zmq_listener, zmq_connector)
for message in test_messages:
zmq_listener.send_message(message)
assert zmq_connector.recv_message() == message
def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
zmq_listener.register_callback(lambda msg: None)
zmq_connector.register_callback(lambda msg: None)
with pytest.raises(RuntimeError):
zmq_listener.recv_message()
with pytest.raises(RuntimeError):
zmq_connector.recv_message()

View File

@ -4,7 +4,11 @@ import zmq
from zmq.eventloop.zmqstream import ZMQStream from zmq.eventloop.zmqstream import ZMQStream
from .scope import Scope from .scope import Scope
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg from .serialization import (
serialize_tfw_msg,
deserialize_tfw_msg,
with_deserialize_tfw_msg
)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
self.keys.remove(key) self.keys.remove(key)
def register_callback(self, callback): def register_callback(self, callback):
self._on_recv_callback = callback if callback:
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv)) self._on_recv_callback = callback
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
else:
self._zmq_sub_stream.on_recv(None)
def recv_message(self, *, block=True):
if self._zmq_sub_stream.receiving():
raise RuntimeError('Synchronous recv() called while a callback is registered!')
flags = 0 if block else zmq.NOBLOCK
try:
return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
except zmq.ZMQError:
raise IOError("No data available to recv!")
def _on_recv(self, message): def _on_recv(self, message):
key = message['key'] key = message['key']

View File

@ -3,7 +3,11 @@ import logging
import zmq import zmq
from zmq.eventloop.zmqstream import ZMQStream from zmq.eventloop.zmqstream import ZMQStream
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg from .serialization import (
serialize_tfw_msg,
deserialize_tfw_msg,
with_deserialize_tfw_msg
)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -17,9 +21,18 @@ class ZMQDownlinkListener:
LOG.debug('Pull socket bound to %s', bind_addr) LOG.debug('Pull socket bound to %s', bind_addr)
def register_callback(self, callback): def register_callback(self, callback):
callback = with_deserialize_tfw_msg(callback) callback = with_deserialize_tfw_msg(callback) if callback else None
self._zmq_pull_stream.on_recv(callback) self._zmq_pull_stream.on_recv(callback)
def recv_message(self, *, block=True):
if self._zmq_pull_stream.receiving():
raise RuntimeError('Synchronous recv() called while a callback is registered!')
flags = 0 if block else zmq.NOBLOCK
try:
return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
except zmq.ZMQError:
raise IOError("No data available to recv!")
def close(self): def close(self):
self._zmq_pull_stream.close() self._zmq_pull_stream.close()