mirror of
https://github.com/avatao-content/baseimage-tutorial-framework
synced 2024-11-15 02:37:17 +00:00
Implement sync API for networking
This commit is contained in:
parent
78c3a8cf98
commit
911831fdb1
@ -3,6 +3,7 @@ from os.path import join
|
|||||||
from secrets import token_urlsafe
|
from secrets import token_urlsafe
|
||||||
from random import randint
|
from random import randint
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
|
from contextlib import suppress
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
@ -63,6 +64,26 @@ def test_messages():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def wait_until_subscriber_connects(listener, connector):
|
||||||
|
# Warning: you are better off without comprehending how this works
|
||||||
|
# Reference: ZMQ PUB-SUB slow joiner problem
|
||||||
|
|
||||||
|
# Wait until something can go through the connection
|
||||||
|
dummy = {'key': '-'}
|
||||||
|
while True:
|
||||||
|
listener.send_message(dummy)
|
||||||
|
with suppress(IOError):
|
||||||
|
if connector.recv_message(block=False) == dummy:
|
||||||
|
break
|
||||||
|
# Throw away leftover messages from last while loop
|
||||||
|
sentinel = {'key': '_'}
|
||||||
|
listener.send_message(sentinel)
|
||||||
|
while True:
|
||||||
|
with suppress(IOError):
|
||||||
|
if connector.recv_message(block=False) == sentinel:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
|
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
|
||||||
messages = []
|
messages = []
|
||||||
zmq_listener.register_callback(messages.append)
|
zmq_listener.register_callback(messages.append)
|
||||||
@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
|
|||||||
|
|
||||||
assert messages == key1_messages
|
assert messages == key1_messages
|
||||||
assert all((msg not in messages for msg in key2_messages))
|
assert all((msg not in messages for msg in key2_messages))
|
||||||
|
|
||||||
|
|
||||||
|
def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
|
||||||
|
for message in test_messages:
|
||||||
|
zmq_connector.send_message(message)
|
||||||
|
assert zmq_listener.recv_message() == message
|
||||||
|
|
||||||
|
|
||||||
|
def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
|
||||||
|
zmq_connector.subscribe('')
|
||||||
|
wait_until_subscriber_connects(zmq_listener, zmq_connector)
|
||||||
|
for message in test_messages:
|
||||||
|
zmq_listener.send_message(message)
|
||||||
|
assert zmq_connector.recv_message() == message
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
|
||||||
|
zmq_listener.register_callback(lambda msg: None)
|
||||||
|
zmq_connector.register_callback(lambda msg: None)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
zmq_listener.recv_message()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
zmq_connector.recv_message()
|
||||||
|
@ -4,7 +4,11 @@ import zmq
|
|||||||
from zmq.eventloop.zmqstream import ZMQStream
|
from zmq.eventloop.zmqstream import ZMQStream
|
||||||
|
|
||||||
from .scope import Scope
|
from .scope import Scope
|
||||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
|
from .serialization import (
|
||||||
|
serialize_tfw_msg,
|
||||||
|
deserialize_tfw_msg,
|
||||||
|
with_deserialize_tfw_msg
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
|
|||||||
self.keys.remove(key)
|
self.keys.remove(key)
|
||||||
|
|
||||||
def register_callback(self, callback):
|
def register_callback(self, callback):
|
||||||
self._on_recv_callback = callback
|
if callback:
|
||||||
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
|
self._on_recv_callback = callback
|
||||||
|
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
|
||||||
|
else:
|
||||||
|
self._zmq_sub_stream.on_recv(None)
|
||||||
|
|
||||||
|
def recv_message(self, *, block=True):
|
||||||
|
if self._zmq_sub_stream.receiving():
|
||||||
|
raise RuntimeError('Synchronous recv() called while a callback is registered!')
|
||||||
|
flags = 0 if block else zmq.NOBLOCK
|
||||||
|
try:
|
||||||
|
return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
|
||||||
|
except zmq.ZMQError:
|
||||||
|
raise IOError("No data available to recv!")
|
||||||
|
|
||||||
def _on_recv(self, message):
|
def _on_recv(self, message):
|
||||||
key = message['key']
|
key = message['key']
|
||||||
|
@ -3,7 +3,11 @@ import logging
|
|||||||
import zmq
|
import zmq
|
||||||
from zmq.eventloop.zmqstream import ZMQStream
|
from zmq.eventloop.zmqstream import ZMQStream
|
||||||
|
|
||||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
|
from .serialization import (
|
||||||
|
serialize_tfw_msg,
|
||||||
|
deserialize_tfw_msg,
|
||||||
|
with_deserialize_tfw_msg
|
||||||
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -17,9 +21,18 @@ class ZMQDownlinkListener:
|
|||||||
LOG.debug('Pull socket bound to %s', bind_addr)
|
LOG.debug('Pull socket bound to %s', bind_addr)
|
||||||
|
|
||||||
def register_callback(self, callback):
|
def register_callback(self, callback):
|
||||||
callback = with_deserialize_tfw_msg(callback)
|
callback = with_deserialize_tfw_msg(callback) if callback else None
|
||||||
self._zmq_pull_stream.on_recv(callback)
|
self._zmq_pull_stream.on_recv(callback)
|
||||||
|
|
||||||
|
def recv_message(self, *, block=True):
|
||||||
|
if self._zmq_pull_stream.receiving():
|
||||||
|
raise RuntimeError('Synchronous recv() called while a callback is registered!')
|
||||||
|
flags = 0 if block else zmq.NOBLOCK
|
||||||
|
try:
|
||||||
|
return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
|
||||||
|
except zmq.ZMQError:
|
||||||
|
raise IOError("No data available to recv!")
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._zmq_pull_stream.close()
|
self._zmq_pull_stream.close()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user