mirror of
https://github.com/avatao-content/baseimage-tutorial-framework
synced 2024-11-14 23:07:16 +00:00
Implement sync API for networking
This commit is contained in:
parent
78c3a8cf98
commit
911831fdb1
@ -3,6 +3,7 @@ from os.path import join
|
||||
from secrets import token_urlsafe
|
||||
from random import randint
|
||||
from tempfile import TemporaryDirectory
|
||||
from contextlib import suppress
|
||||
|
||||
import pytest
|
||||
from tornado.ioloop import IOLoop
|
||||
@ -63,6 +64,26 @@ def test_messages():
|
||||
]
|
||||
|
||||
|
||||
def wait_until_subscriber_connects(listener, connector):
|
||||
# Warning: you are better off without comprehending how this works
|
||||
# Reference: ZMQ PUB-SUB slow joiner problem
|
||||
|
||||
# Wait until something can go through the connection
|
||||
dummy = {'key': '-'}
|
||||
while True:
|
||||
listener.send_message(dummy)
|
||||
with suppress(IOError):
|
||||
if connector.recv_message(block=False) == dummy:
|
||||
break
|
||||
# Throw away leftover messages from last while loop
|
||||
sentinel = {'key': '_'}
|
||||
listener.send_message(sentinel)
|
||||
while True:
|
||||
with suppress(IOError):
|
||||
if connector.recv_message(block=False) == sentinel:
|
||||
break
|
||||
|
||||
|
||||
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
|
||||
messages = []
|
||||
zmq_listener.register_callback(messages.append)
|
||||
@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
|
||||
|
||||
assert messages == key1_messages
|
||||
assert all((msg not in messages for msg in key2_messages))
|
||||
|
||||
|
||||
def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
|
||||
for message in test_messages:
|
||||
zmq_connector.send_message(message)
|
||||
assert zmq_listener.recv_message() == message
|
||||
|
||||
|
||||
def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
|
||||
zmq_connector.subscribe('')
|
||||
wait_until_subscriber_connects(zmq_listener, zmq_connector)
|
||||
for message in test_messages:
|
||||
zmq_listener.send_message(message)
|
||||
assert zmq_connector.recv_message() == message
|
||||
|
||||
|
||||
def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
|
||||
zmq_listener.register_callback(lambda msg: None)
|
||||
zmq_connector.register_callback(lambda msg: None)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
zmq_listener.recv_message()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
zmq_connector.recv_message()
|
||||
|
@ -4,7 +4,11 @@ import zmq
|
||||
from zmq.eventloop.zmqstream import ZMQStream
|
||||
|
||||
from .scope import Scope
|
||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
|
||||
from .serialization import (
|
||||
serialize_tfw_msg,
|
||||
deserialize_tfw_msg,
|
||||
with_deserialize_tfw_msg
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
|
||||
self.keys.remove(key)
|
||||
|
||||
def register_callback(self, callback):
|
||||
self._on_recv_callback = callback
|
||||
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
|
||||
if callback:
|
||||
self._on_recv_callback = callback
|
||||
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
|
||||
else:
|
||||
self._zmq_sub_stream.on_recv(None)
|
||||
|
||||
def recv_message(self, *, block=True):
|
||||
if self._zmq_sub_stream.receiving():
|
||||
raise RuntimeError('Synchronous recv() called while a callback is registered!')
|
||||
flags = 0 if block else zmq.NOBLOCK
|
||||
try:
|
||||
return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
|
||||
except zmq.ZMQError:
|
||||
raise IOError("No data available to recv!")
|
||||
|
||||
def _on_recv(self, message):
|
||||
key = message['key']
|
||||
|
@ -3,7 +3,11 @@ import logging
|
||||
import zmq
|
||||
from zmq.eventloop.zmqstream import ZMQStream
|
||||
|
||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
|
||||
from .serialization import (
|
||||
serialize_tfw_msg,
|
||||
deserialize_tfw_msg,
|
||||
with_deserialize_tfw_msg
|
||||
)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -17,9 +21,18 @@ class ZMQDownlinkListener:
|
||||
LOG.debug('Pull socket bound to %s', bind_addr)
|
||||
|
||||
def register_callback(self, callback):
|
||||
callback = with_deserialize_tfw_msg(callback)
|
||||
callback = with_deserialize_tfw_msg(callback) if callback else None
|
||||
self._zmq_pull_stream.on_recv(callback)
|
||||
|
||||
def recv_message(self, *, block=True):
|
||||
if self._zmq_pull_stream.receiving():
|
||||
raise RuntimeError('Synchronous recv() called while a callback is registered!')
|
||||
flags = 0 if block else zmq.NOBLOCK
|
||||
try:
|
||||
return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
|
||||
except zmq.ZMQError:
|
||||
raise IOError("No data available to recv!")
|
||||
|
||||
def close(self):
|
||||
self._zmq_pull_stream.close()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user