mirror of
				https://github.com/avatao-content/baseimage-tutorial-framework
				synced 2025-11-04 13:12:55 +00:00 
			
		
		
		
	Implement sync API for networking
This commit is contained in:
		@@ -3,6 +3,7 @@ from os.path import join
 | 
			
		||||
from secrets import token_urlsafe
 | 
			
		||||
from random import randint
 | 
			
		||||
from tempfile import TemporaryDirectory
 | 
			
		||||
from contextlib import suppress
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from tornado.ioloop import IOLoop
 | 
			
		||||
@@ -63,6 +64,26 @@ def test_messages():
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def wait_until_subscriber_connects(listener, connector):
 | 
			
		||||
    # Warning: you are better off without comprehending how this works
 | 
			
		||||
    # Reference: ZMQ PUB-SUB slow joiner problem
 | 
			
		||||
 | 
			
		||||
    # Wait until something can go through the connection
 | 
			
		||||
    dummy = {'key': '-'}
 | 
			
		||||
    while True:
 | 
			
		||||
        listener.send_message(dummy)
 | 
			
		||||
        with suppress(IOError):
 | 
			
		||||
            if connector.recv_message(block=False) == dummy:
 | 
			
		||||
                break
 | 
			
		||||
    # Throw away leftover messages from last while loop
 | 
			
		||||
    sentinel = {'key': '_'}
 | 
			
		||||
    listener.send_message(sentinel)
 | 
			
		||||
    while True:
 | 
			
		||||
        with suppress(IOError):
 | 
			
		||||
            if connector.recv_message(block=False) == sentinel:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
 | 
			
		||||
    messages = []
 | 
			
		||||
    zmq_listener.register_callback(messages.append)
 | 
			
		||||
@@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
 | 
			
		||||
 | 
			
		||||
    assert messages == key1_messages
 | 
			
		||||
    assert all((msg not in messages for msg in key2_messages))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
 | 
			
		||||
    for message in test_messages:
 | 
			
		||||
        zmq_connector.send_message(message)
 | 
			
		||||
        assert zmq_listener.recv_message() == message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
 | 
			
		||||
    zmq_connector.subscribe('')
 | 
			
		||||
    wait_until_subscriber_connects(zmq_listener, zmq_connector)
 | 
			
		||||
    for message in test_messages:
 | 
			
		||||
        zmq_listener.send_message(message)
 | 
			
		||||
        assert zmq_connector.recv_message() == message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
 | 
			
		||||
    zmq_listener.register_callback(lambda msg: None)
 | 
			
		||||
    zmq_connector.register_callback(lambda msg: None)
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(RuntimeError):
 | 
			
		||||
        zmq_listener.recv_message()
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(RuntimeError):
 | 
			
		||||
        zmq_connector.recv_message()
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,11 @@ import zmq
 | 
			
		||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
			
		||||
 | 
			
		||||
from .scope import Scope
 | 
			
		||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
			
		||||
from .serialization import (
 | 
			
		||||
    serialize_tfw_msg,
 | 
			
		||||
    deserialize_tfw_msg,
 | 
			
		||||
    with_deserialize_tfw_msg
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
 | 
			
		||||
            self.keys.remove(key)
 | 
			
		||||
 | 
			
		||||
    def register_callback(self, callback):
 | 
			
		||||
        self._on_recv_callback = callback
 | 
			
		||||
        self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
 | 
			
		||||
        if callback:
 | 
			
		||||
            self._on_recv_callback = callback
 | 
			
		||||
            self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
 | 
			
		||||
        else:
 | 
			
		||||
            self._zmq_sub_stream.on_recv(None)
 | 
			
		||||
 | 
			
		||||
    def recv_message(self, *, block=True):
 | 
			
		||||
        if self._zmq_sub_stream.receiving():
 | 
			
		||||
            raise RuntimeError('Synchronous recv() called while a callback is registered!')
 | 
			
		||||
        flags = 0 if block else zmq.NOBLOCK
 | 
			
		||||
        try:
 | 
			
		||||
            return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
 | 
			
		||||
        except zmq.ZMQError:
 | 
			
		||||
            raise IOError("No data available to recv!")
 | 
			
		||||
 | 
			
		||||
    def _on_recv(self, message):
 | 
			
		||||
        key = message['key']
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,11 @@ import logging
 | 
			
		||||
import zmq
 | 
			
		||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
			
		||||
 | 
			
		||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
			
		||||
from .serialization import (
 | 
			
		||||
    serialize_tfw_msg,
 | 
			
		||||
    deserialize_tfw_msg,
 | 
			
		||||
    with_deserialize_tfw_msg
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
@@ -17,9 +21,18 @@ class ZMQDownlinkListener:
 | 
			
		||||
        LOG.debug('Pull socket bound to %s', bind_addr)
 | 
			
		||||
 | 
			
		||||
    def register_callback(self, callback):
 | 
			
		||||
        callback = with_deserialize_tfw_msg(callback)
 | 
			
		||||
        callback = with_deserialize_tfw_msg(callback) if callback else None
 | 
			
		||||
        self._zmq_pull_stream.on_recv(callback)
 | 
			
		||||
 | 
			
		||||
    def recv_message(self, *, block=True):
 | 
			
		||||
        if self._zmq_pull_stream.receiving():
 | 
			
		||||
            raise RuntimeError('Synchronous recv() called while a callback is registered!')
 | 
			
		||||
        flags = 0 if block else zmq.NOBLOCK
 | 
			
		||||
        try:
 | 
			
		||||
            return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
 | 
			
		||||
        except zmq.ZMQError:
 | 
			
		||||
            raise IOError("No data available to recv!")
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._zmq_pull_stream.close()
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user