mirror of
				https://github.com/avatao-content/baseimage-tutorial-framework
				synced 2025-11-04 01:42:55 +00:00 
			
		
		
		
	Implement sync API for networking
This commit is contained in:
		@@ -3,6 +3,7 @@ from os.path import join
 | 
				
			|||||||
from secrets import token_urlsafe
 | 
					from secrets import token_urlsafe
 | 
				
			||||||
from random import randint
 | 
					from random import randint
 | 
				
			||||||
from tempfile import TemporaryDirectory
 | 
					from tempfile import TemporaryDirectory
 | 
				
			||||||
 | 
					from contextlib import suppress
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from tornado.ioloop import IOLoop
 | 
					from tornado.ioloop import IOLoop
 | 
				
			||||||
@@ -63,6 +64,26 @@ def test_messages():
 | 
				
			|||||||
    ]
 | 
					    ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def wait_until_subscriber_connects(listener, connector):
 | 
				
			||||||
 | 
					    # Warning: you are better off without comprehending how this works
 | 
				
			||||||
 | 
					    # Reference: ZMQ PUB-SUB slow joiner problem
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Wait until something can go through the connection
 | 
				
			||||||
 | 
					    dummy = {'key': '-'}
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        listener.send_message(dummy)
 | 
				
			||||||
 | 
					        with suppress(IOError):
 | 
				
			||||||
 | 
					            if connector.recv_message(block=False) == dummy:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					    # Throw away leftover messages from last while loop
 | 
				
			||||||
 | 
					    sentinel = {'key': '_'}
 | 
				
			||||||
 | 
					    listener.send_message(sentinel)
 | 
				
			||||||
 | 
					    while True:
 | 
				
			||||||
 | 
					        with suppress(IOError):
 | 
				
			||||||
 | 
					            if connector.recv_message(block=False) == sentinel:
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_server_downlink(zmq_listener, zmq_connector, test_messages):
 | 
					def test_server_downlink(zmq_listener, zmq_connector, test_messages):
 | 
				
			||||||
    messages = []
 | 
					    messages = []
 | 
				
			||||||
    zmq_listener.register_callback(messages.append)
 | 
					    zmq_listener.register_callback(messages.append)
 | 
				
			||||||
@@ -104,3 +125,28 @@ def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    assert messages == key1_messages
 | 
					    assert messages == key1_messages
 | 
				
			||||||
    assert all((msg not in messages for msg in key2_messages))
 | 
					    assert all((msg not in messages for msg in key2_messages))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
 | 
				
			||||||
 | 
					    for message in test_messages:
 | 
				
			||||||
 | 
					        zmq_connector.send_message(message)
 | 
				
			||||||
 | 
					        assert zmq_listener.recv_message() == message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
 | 
				
			||||||
 | 
					    zmq_connector.subscribe('')
 | 
				
			||||||
 | 
					    wait_until_subscriber_connects(zmq_listener, zmq_connector)
 | 
				
			||||||
 | 
					    for message in test_messages:
 | 
				
			||||||
 | 
					        zmq_listener.send_message(message)
 | 
				
			||||||
 | 
					        assert zmq_connector.recv_message() == message
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
 | 
				
			||||||
 | 
					    zmq_listener.register_callback(lambda msg: None)
 | 
				
			||||||
 | 
					    zmq_connector.register_callback(lambda msg: None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(RuntimeError):
 | 
				
			||||||
 | 
					        zmq_listener.recv_message()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with pytest.raises(RuntimeError):
 | 
				
			||||||
 | 
					        zmq_connector.recv_message()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,7 +4,11 @@ import zmq
 | 
				
			|||||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
					from zmq.eventloop.zmqstream import ZMQStream
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .scope import Scope
 | 
					from .scope import Scope
 | 
				
			||||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
					from .serialization import (
 | 
				
			||||||
 | 
					    serialize_tfw_msg,
 | 
				
			||||||
 | 
					    deserialize_tfw_msg,
 | 
				
			||||||
 | 
					    with_deserialize_tfw_msg
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOG = logging.getLogger(__name__)
 | 
					LOG = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -29,8 +33,20 @@ class ZMQDownlinkConnector:
 | 
				
			|||||||
            self.keys.remove(key)
 | 
					            self.keys.remove(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def register_callback(self, callback):
 | 
					    def register_callback(self, callback):
 | 
				
			||||||
 | 
					        if callback:
 | 
				
			||||||
            self._on_recv_callback = callback
 | 
					            self._on_recv_callback = callback
 | 
				
			||||||
            self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
 | 
					            self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self._zmq_sub_stream.on_recv(None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_message(self, *, block=True):
 | 
				
			||||||
 | 
					        if self._zmq_sub_stream.receiving():
 | 
				
			||||||
 | 
					            raise RuntimeError('Synchronous recv() called while a callback is registered!')
 | 
				
			||||||
 | 
					        flags = 0 if block else zmq.NOBLOCK
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return deserialize_tfw_msg(*self._zmq_sub_socket.recv_multipart(flags))
 | 
				
			||||||
 | 
					        except zmq.ZMQError:
 | 
				
			||||||
 | 
					            raise IOError("No data available to recv!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _on_recv(self, message):
 | 
					    def _on_recv(self, message):
 | 
				
			||||||
        key = message['key']
 | 
					        key = message['key']
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,7 +3,11 @@ import logging
 | 
				
			|||||||
import zmq
 | 
					import zmq
 | 
				
			||||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
					from zmq.eventloop.zmqstream import ZMQStream
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
					from .serialization import (
 | 
				
			||||||
 | 
					    serialize_tfw_msg,
 | 
				
			||||||
 | 
					    deserialize_tfw_msg,
 | 
				
			||||||
 | 
					    with_deserialize_tfw_msg
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
LOG = logging.getLogger(__name__)
 | 
					LOG = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -17,9 +21,18 @@ class ZMQDownlinkListener:
 | 
				
			|||||||
        LOG.debug('Pull socket bound to %s', bind_addr)
 | 
					        LOG.debug('Pull socket bound to %s', bind_addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def register_callback(self, callback):
 | 
					    def register_callback(self, callback):
 | 
				
			||||||
        callback = with_deserialize_tfw_msg(callback)
 | 
					        callback = with_deserialize_tfw_msg(callback) if callback else None
 | 
				
			||||||
        self._zmq_pull_stream.on_recv(callback)
 | 
					        self._zmq_pull_stream.on_recv(callback)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def recv_message(self, *, block=True):
 | 
				
			||||||
 | 
					        if self._zmq_pull_stream.receiving():
 | 
				
			||||||
 | 
					            raise RuntimeError('Synchronous recv() called while a callback is registered!')
 | 
				
			||||||
 | 
					        flags = 0 if block else zmq.NOBLOCK
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return deserialize_tfw_msg(*self._zmq_pull_socket.recv_multipart(flags))
 | 
				
			||||||
 | 
					        except zmq.ZMQError:
 | 
				
			||||||
 | 
					            raise IOError("No data available to recv!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def close(self):
 | 
					    def close(self):
 | 
				
			||||||
        self._zmq_pull_stream.close()
 | 
					        self._zmq_pull_stream.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user