mirror of
				https://github.com/avatao-content/baseimage-tutorial-framework
				synced 2025-11-04 07:42:54 +00:00 
			
		
		
		
	Simplify package structure
This commit is contained in:
		
							
								
								
									
										0
									
								
								tfw/internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tfw/internals/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										35
									
								
								tfw/internals/callback_mixin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								tfw/internals/callback_mixin.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,35 @@
 | 
			
		||||
from functools import partial
 | 
			
		||||
 | 
			
		||||
from .lazy import lazy_property
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CallbackMixin:
 | 
			
		||||
    @lazy_property
 | 
			
		||||
    def _callbacks(self):
 | 
			
		||||
        # pylint: disable=no-self-use
 | 
			
		||||
        return []
 | 
			
		||||
 | 
			
		||||
    def subscribe_callback(self, callback, *args, **kwargs):
 | 
			
		||||
        """
 | 
			
		||||
        Subscribe a callable to invoke once an event is triggered.
 | 
			
		||||
        :param callback: callable to be executed on events
 | 
			
		||||
        :param args: arguments passed to callable
 | 
			
		||||
        :param kwargs: kwargs passed to callable
 | 
			
		||||
        """
 | 
			
		||||
        fun = partial(callback, *args, **kwargs)
 | 
			
		||||
        self._callbacks.append(fun)
 | 
			
		||||
 | 
			
		||||
    def subscribe_callbacks(self, *callbacks):
 | 
			
		||||
        """
 | 
			
		||||
        Subscribe a list of callbacks to incoke once an event is triggered.
 | 
			
		||||
        :param callbacks: callbacks to be subscribed
 | 
			
		||||
        """
 | 
			
		||||
        for callback in callbacks:
 | 
			
		||||
            self.subscribe_callback(callback)
 | 
			
		||||
 | 
			
		||||
    def unsubscribe_callback(self, callback):
 | 
			
		||||
        self._callbacks.remove(callback)
 | 
			
		||||
 | 
			
		||||
    def _execute_callbacks(self, *args, **kwargs):
 | 
			
		||||
        for callback in self._callbacks:
 | 
			
		||||
            callback(*args, **kwargs)
 | 
			
		||||
							
								
								
									
										107
									
								
								tfw/internals/crypto.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										107
									
								
								tfw/internals/crypto.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,107 @@
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from base64 import b64encode, b64decode
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from hashlib import md5
 | 
			
		||||
from os import urandom, chmod
 | 
			
		||||
from os.path import exists
 | 
			
		||||
from stat import S_IRUSR, S_IWUSR, S_IXUSR
 | 
			
		||||
 | 
			
		||||
from cryptography.hazmat.backends import default_backend
 | 
			
		||||
from cryptography.hazmat.primitives.hashes import SHA256
 | 
			
		||||
from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
 | 
			
		||||
from cryptography.exceptions import InvalidSignature
 | 
			
		||||
 | 
			
		||||
from tfw.internals.networking import message_bytes
 | 
			
		||||
from tfw.internals.lazy import lazy_property
 | 
			
		||||
from tfw.config import TFWENV
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_checksum(message):
 | 
			
		||||
    return md5(message_bytes(message)).hexdigest()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sign_message(key, message):
 | 
			
		||||
    message.pop('scope', None)
 | 
			
		||||
    message.pop('signature', None)
 | 
			
		||||
    signature = message_signature(key, message)
 | 
			
		||||
    message['signature'] = b64encode(signature).decode()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_signature(key, message):
 | 
			
		||||
    return HMAC(key, message_bytes(message)).signature
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_message(key, message):
 | 
			
		||||
    message.pop('scope', None)
 | 
			
		||||
    message = deepcopy(message)
 | 
			
		||||
    try:
 | 
			
		||||
        signature_b64 = message.pop('signature')
 | 
			
		||||
        signature = b64decode(signature_b64)
 | 
			
		||||
        actual_signature = message_signature(key, message)
 | 
			
		||||
        return signature == actual_signature
 | 
			
		||||
    except KeyError:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KeyManager:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.keyfile = TFWENV.AUTH_KEY
 | 
			
		||||
        if not exists(self.keyfile):
 | 
			
		||||
            self._init_auth_key()
 | 
			
		||||
 | 
			
		||||
    @lazy_property
 | 
			
		||||
    def auth_key(self):
 | 
			
		||||
        with open(self.keyfile, 'rb') as ifile:
 | 
			
		||||
            return ifile.read()
 | 
			
		||||
 | 
			
		||||
    def _init_auth_key(self):
 | 
			
		||||
        key = self.generate_key()
 | 
			
		||||
        with open(self.keyfile, 'wb') as ofile:
 | 
			
		||||
            ofile.write(key)
 | 
			
		||||
        self._chmod_700_keyfile()
 | 
			
		||||
        return key
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def generate_key():
 | 
			
		||||
        return urandom(32)
 | 
			
		||||
 | 
			
		||||
    def _chmod_700_keyfile(self):
 | 
			
		||||
        chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HMAC:
 | 
			
		||||
    def __init__(self, key, message):
 | 
			
		||||
        self.key = key
 | 
			
		||||
        self.message = message
 | 
			
		||||
        self._hmac = _HMAC(
 | 
			
		||||
            key=key,
 | 
			
		||||
            algorithm=SHA256(),
 | 
			
		||||
            backend=default_backend()
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _reload_if_finalized(f):
 | 
			
		||||
        # pylint: disable=no-self-argument,not-callable
 | 
			
		||||
        @wraps(f)
 | 
			
		||||
        def wrapped(instance, *args, **kwargs):
 | 
			
		||||
            if getattr(instance, '_finalized', False):
 | 
			
		||||
                instance.__init__(instance.key, instance.message)
 | 
			
		||||
            ret_val = f(instance, *args, **kwargs)
 | 
			
		||||
            setattr(instance, '_finalized', True)
 | 
			
		||||
            return ret_val
 | 
			
		||||
        return wrapped
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    @_reload_if_finalized
 | 
			
		||||
    def signature(self):
 | 
			
		||||
        self._hmac.update(self.message)
 | 
			
		||||
        signature = self._hmac.finalize()
 | 
			
		||||
        return signature
 | 
			
		||||
 | 
			
		||||
    @_reload_if_finalized
 | 
			
		||||
    def verify(self, signature):
 | 
			
		||||
        self._hmac.update(self.message)
 | 
			
		||||
        try:
 | 
			
		||||
            self._hmac.verify(signature)
 | 
			
		||||
            return True
 | 
			
		||||
        except InvalidSignature:
 | 
			
		||||
            return False
 | 
			
		||||
							
								
								
									
										3
									
								
								tfw/internals/event_handling/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tfw/internals/event_handling/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
			
		||||
from .event_handler_factory_base import EventHandlerFactoryBase
 | 
			
		||||
from .event_handler import EventHandler
 | 
			
		||||
from .fsm_aware_event_handler import FSMAwareEventHandler
 | 
			
		||||
							
								
								
									
										27
									
								
								tfw/internals/event_handling/event_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								tfw/internals/event_handling/event_handler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
class EventHandler:
 | 
			
		||||
    _instances = set()
 | 
			
		||||
 | 
			
		||||
    def __init__(self, server_connector):
 | 
			
		||||
        type(self)._instances.add(self)
 | 
			
		||||
        self.server_connector = server_connector
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.server_connector.register_callback(self._event_callback)
 | 
			
		||||
 | 
			
		||||
    def _event_callback(self, message):
 | 
			
		||||
        self.handle_event(message, self.server_connector)
 | 
			
		||||
 | 
			
		||||
    def handle_event(self, message, server_connector):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def stop_all_instances(cls):
 | 
			
		||||
        for instance in cls._instances:
 | 
			
		||||
            instance.stop()
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        self.server_connector.close()
 | 
			
		||||
        self.cleanup()
 | 
			
		||||
 | 
			
		||||
    def cleanup(self):
 | 
			
		||||
        pass
 | 
			
		||||
							
								
								
									
										68
									
								
								tfw/internals/event_handling/event_handler_factory_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								tfw/internals/event_handling/event_handler_factory_base.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,68 @@
 | 
			
		||||
from contextlib import suppress
 | 
			
		||||
 | 
			
		||||
from .event_handler import EventHandler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventHandlerFactoryBase:
 | 
			
		||||
    def build(self, handler_stub, *, keys=None, event_handler_type=EventHandler):
 | 
			
		||||
        builder = EventHandlerBuilder(handler_stub, keys, event_handler_type)
 | 
			
		||||
        server_connector = self._build_server_connector()
 | 
			
		||||
        event_handler = builder.build(server_connector)
 | 
			
		||||
        handler_stub.server_connector = server_connector
 | 
			
		||||
        with suppress(AttributeError):
 | 
			
		||||
            handler_stub.start()
 | 
			
		||||
        event_handler.start()
 | 
			
		||||
        return event_handler
 | 
			
		||||
 | 
			
		||||
    def _build_server_connector(self):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventHandlerBuilder:
 | 
			
		||||
    def __init__(self, event_handler, supplied_keys, event_handler_type):
 | 
			
		||||
        self._analyzer = HandlerStubAnalyzer(event_handler, supplied_keys)
 | 
			
		||||
        self._event_handler_type = event_handler_type
 | 
			
		||||
 | 
			
		||||
    def build(self, server_connector):
 | 
			
		||||
        event_handler = self._event_handler_type(server_connector)
 | 
			
		||||
        server_connector.subscribe(*self._try_get_keys(event_handler))
 | 
			
		||||
        event_handler.handle_event = self._analyzer.handle_event
 | 
			
		||||
        with suppress(AttributeError):
 | 
			
		||||
            event_handler.cleanup = self._analyzer.cleanup
 | 
			
		||||
        return event_handler
 | 
			
		||||
 | 
			
		||||
    def _try_get_keys(self, event_handler):
 | 
			
		||||
        try:
 | 
			
		||||
            return self._analyzer.keys
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            with suppress(AttributeError):
 | 
			
		||||
                return event_handler.keys
 | 
			
		||||
            raise
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HandlerStubAnalyzer:
 | 
			
		||||
    def __init__(self, event_handler, supplied_keys):
 | 
			
		||||
        self._event_handler = event_handler
 | 
			
		||||
        self._supplied_keys = supplied_keys
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def keys(self):
 | 
			
		||||
        if self._supplied_keys is None:
 | 
			
		||||
            try:
 | 
			
		||||
                return self._event_handler.keys
 | 
			
		||||
            except AttributeError:
 | 
			
		||||
                raise ValueError('No keys supplied!')
 | 
			
		||||
        return self._supplied_keys
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def handle_event(self):
 | 
			
		||||
        try:
 | 
			
		||||
            return self._event_handler.handle_event
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            if callable(self._event_handler):
 | 
			
		||||
                return self._event_handler
 | 
			
		||||
            raise ValueError('Object must implement handle_event or be a callable!')
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def cleanup(self):
 | 
			
		||||
        return self._event_handler.cleanup
 | 
			
		||||
							
								
								
									
										37
									
								
								tfw/internals/event_handling/fsm_aware.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										37
									
								
								tfw/internals/event_handling/fsm_aware.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,37 @@
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
from tfw.internals.crypto import KeyManager, verify_message
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FSMAware:
 | 
			
		||||
    keys = ['fsm_update']
 | 
			
		||||
    """
 | 
			
		||||
    Base class for stuff that has to be aware of the framework FSM.
 | 
			
		||||
    This is done by processing 'fsm_update' messages.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.fsm_state = None
 | 
			
		||||
        self.fsm_in_accepted_state = False
 | 
			
		||||
        self.fsm_event_log = []
 | 
			
		||||
        self._auth_key = KeyManager().auth_key
 | 
			
		||||
 | 
			
		||||
    def process_message(self, message):
 | 
			
		||||
        if message['key'] == 'fsm_update':
 | 
			
		||||
            if verify_message(self._auth_key, message):
 | 
			
		||||
                self._handle_fsm_update(message)
 | 
			
		||||
 | 
			
		||||
    def _handle_fsm_update(self, message):
 | 
			
		||||
        try:
 | 
			
		||||
            new_state = message['current_state']
 | 
			
		||||
            if self.fsm_state != new_state:
 | 
			
		||||
                self.handle_fsm_step(message)
 | 
			
		||||
            self.fsm_state = new_state
 | 
			
		||||
            self.fsm_in_accepted_state = message['in_accepted_state']
 | 
			
		||||
            self.fsm_event_log.append(message)
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            LOG.error('Invalid fsm_update message received!')
 | 
			
		||||
 | 
			
		||||
    def handle_fsm_step(self, message):
 | 
			
		||||
        pass
 | 
			
		||||
							
								
								
									
										19
									
								
								tfw/internals/event_handling/fsm_aware_event_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								tfw/internals/event_handling/fsm_aware_event_handler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
from .event_handler import EventHandler
 | 
			
		||||
from .fsm_aware import FSMAware
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FSMAwareEventHandler(EventHandler, FSMAware):
 | 
			
		||||
    # pylint: disable=abstract-method
 | 
			
		||||
    """
 | 
			
		||||
    Abstract base class for EventHandlers which automatically
 | 
			
		||||
    keep track of the state of the TFW FSM.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, server_connector):
 | 
			
		||||
        EventHandler.__init__(self, server_connector)
 | 
			
		||||
        FSMAware.__init__(self)
 | 
			
		||||
 | 
			
		||||
    def _event_callback(self, message):
 | 
			
		||||
        self.process_message(message)
 | 
			
		||||
 | 
			
		||||
    def handle_fsm_step(self, message):
 | 
			
		||||
        self.handle_event(message, self.server_connector)
 | 
			
		||||
							
								
								
									
										190
									
								
								tfw/internals/event_handling/test_event_handler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								tfw/internals/event_handling/test_event_handler.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,190 @@
 | 
			
		||||
# pylint: disable=redefined-outer-name,attribute-defined-outside-init
 | 
			
		||||
from secrets import token_urlsafe
 | 
			
		||||
from random import randint
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from .event_handler_factory_base import EventHandlerFactoryBase
 | 
			
		||||
from .event_handler import EventHandler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockEventHandlerFactory(EventHandlerFactoryBase):
 | 
			
		||||
    def _build_server_connector(self):
 | 
			
		||||
        return MockServerConnector()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockServerConnector:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.keys = []
 | 
			
		||||
        self._on_message = None
 | 
			
		||||
 | 
			
		||||
    def simulate_message(self, message):
 | 
			
		||||
        self._on_message(message)
 | 
			
		||||
 | 
			
		||||
    def register_callback(self, callback):
 | 
			
		||||
        self._on_message = callback
 | 
			
		||||
 | 
			
		||||
    def subscribe(self, *keys):
 | 
			
		||||
        self.keys.extend(keys)
 | 
			
		||||
 | 
			
		||||
    def unsubscribe(self, *keys):
 | 
			
		||||
        for key in keys:
 | 
			
		||||
            self.keys.remove(key)
 | 
			
		||||
 | 
			
		||||
    def send_message(self, message, scope=None):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockEventHandlerStub:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.server_connector = None
 | 
			
		||||
        self.last_message = None
 | 
			
		||||
        self.cleaned_up = False
 | 
			
		||||
        self.started = False
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self.started = True
 | 
			
		||||
 | 
			
		||||
    def cleanup(self):
 | 
			
		||||
        self.cleaned_up = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockEventHandler(MockEventHandlerStub):
 | 
			
		||||
    # pylint: disable=unused-argument
 | 
			
		||||
    def handle_event(self, message, server_connector):
 | 
			
		||||
        self.last_message = message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MockCallable(MockEventHandlerStub):
 | 
			
		||||
    def __call__(self, message, server_connector):
 | 
			
		||||
        self.last_message = message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def test_msg():
 | 
			
		||||
    yield token_urlsafe(randint(16, 64))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def test_keys():
 | 
			
		||||
    yield [
 | 
			
		||||
        token_urlsafe(randint(2, 8))
 | 
			
		||||
        for _ in range(randint(16, 32))
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_object(test_keys, test_msg):
 | 
			
		||||
    mock_eh = MockEventHandlerStub()
 | 
			
		||||
    def handle_event(message, server_connector):
 | 
			
		||||
        raise RuntimeError(message, server_connector.keys)
 | 
			
		||||
    mock_eh.handle_event = handle_event
 | 
			
		||||
 | 
			
		||||
    assert not mock_eh.started
 | 
			
		||||
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
 | 
			
		||||
 | 
			
		||||
    assert mock_eh.started
 | 
			
		||||
    assert mock_eh.server_connector is eh.server_connector
 | 
			
		||||
    with pytest.raises(RuntimeError) as err:
 | 
			
		||||
        eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
        msg, keys = err.args
 | 
			
		||||
        assert msg == test_msg
 | 
			
		||||
        assert keys == test_keys
 | 
			
		||||
    assert not mock_eh.cleaned_up
 | 
			
		||||
    eh.stop()
 | 
			
		||||
    assert mock_eh.cleaned_up
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_object_with_keys(test_keys, test_msg):
 | 
			
		||||
    mock_eh = MockEventHandler()
 | 
			
		||||
    mock_eh.keys = test_keys
 | 
			
		||||
 | 
			
		||||
    assert not mock_eh.started
 | 
			
		||||
    eh = MockEventHandlerFactory().build(mock_eh)
 | 
			
		||||
 | 
			
		||||
    assert mock_eh.server_connector.keys == test_keys
 | 
			
		||||
    assert eh.server_connector is mock_eh.server_connector
 | 
			
		||||
    assert mock_eh.started
 | 
			
		||||
    assert not mock_eh.last_message
 | 
			
		||||
    eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
    assert mock_eh.last_message == test_msg
 | 
			
		||||
    assert not mock_eh.cleaned_up
 | 
			
		||||
    EventHandler.stop_all_instances()
 | 
			
		||||
    assert mock_eh.cleaned_up
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_simple_object(test_keys, test_msg):
 | 
			
		||||
    class SimpleMockEventHandler:
 | 
			
		||||
        # pylint: disable=no-self-use
 | 
			
		||||
        def handle_event(self, message, server_connector):
 | 
			
		||||
            raise RuntimeError(message, server_connector)
 | 
			
		||||
 | 
			
		||||
    mock_eh = SimpleMockEventHandler()
 | 
			
		||||
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(RuntimeError) as err:
 | 
			
		||||
        eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
        msg, keys = err.args
 | 
			
		||||
        assert msg == test_msg
 | 
			
		||||
        assert keys == test_keys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_callable(test_keys, test_msg):
 | 
			
		||||
    mock_eh = MockCallable()
 | 
			
		||||
 | 
			
		||||
    assert not mock_eh.started
 | 
			
		||||
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
 | 
			
		||||
 | 
			
		||||
    assert mock_eh.started
 | 
			
		||||
    assert mock_eh.server_connector is eh.server_connector
 | 
			
		||||
    assert eh.server_connector.keys == test_keys
 | 
			
		||||
    assert not mock_eh.last_message
 | 
			
		||||
    eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
    assert mock_eh.last_message == test_msg
 | 
			
		||||
    assert not mock_eh.cleaned_up
 | 
			
		||||
    eh.stop()
 | 
			
		||||
    assert mock_eh.cleaned_up
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_function(test_keys, test_msg):
 | 
			
		||||
    def some_function(message, server_connector):
 | 
			
		||||
        raise RuntimeError(message, server_connector.keys)
 | 
			
		||||
    eh = MockEventHandlerFactory().build(some_function, keys=test_keys)
 | 
			
		||||
 | 
			
		||||
    assert eh.server_connector.keys == test_keys
 | 
			
		||||
    with pytest.raises(RuntimeError) as err:
 | 
			
		||||
        eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
        msg, keys = err.args
 | 
			
		||||
        assert msg == test_msg
 | 
			
		||||
        assert keys == test_keys
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_from_lambda(test_keys, test_msg):
 | 
			
		||||
    def assert_messages_equal(msg):
 | 
			
		||||
        assert msg == test_msg
 | 
			
		||||
    fun = lambda msg, sc: assert_messages_equal(msg)
 | 
			
		||||
    eh = MockEventHandlerFactory().build(fun, keys=test_keys)
 | 
			
		||||
    eh.server_connector.simulate_message(test_msg)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_build_raises_if_no_key(test_keys):
 | 
			
		||||
    eh = MockEventHandler()
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
        MockEventHandlerFactory().build(eh)
 | 
			
		||||
 | 
			
		||||
    def handle_event(*_):
 | 
			
		||||
        pass
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
        MockEventHandlerFactory().build(handle_event)
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(ValueError):
 | 
			
		||||
        MockEventHandlerFactory().build(lambda msg, sc: None)
 | 
			
		||||
 | 
			
		||||
    WithKeysEventHandler = EventHandler
 | 
			
		||||
    WithKeysEventHandler.keys = test_keys
 | 
			
		||||
    MockEventHandlerFactory().build(eh, event_handler_type=WithKeysEventHandler)
 | 
			
		||||
 | 
			
		||||
    eh.keys = test_keys
 | 
			
		||||
    MockEventHandlerFactory().build(eh)
 | 
			
		||||
							
								
								
									
										6
									
								
								tfw/internals/inotify/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								tfw/internals/inotify/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
from .inotify import InotifyObserver
 | 
			
		||||
from .inotify import (
 | 
			
		||||
    InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
 | 
			
		||||
    InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
 | 
			
		||||
    InotifyDirMovedEvent, InotifyDirDeletedEvent
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										189
									
								
								tfw/internals/inotify/inotify.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										189
									
								
								tfw/internals/inotify/inotify.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,189 @@
 | 
			
		||||
# pylint: disable=too-few-public-methods
 | 
			
		||||
 | 
			
		||||
from typing import Iterable
 | 
			
		||||
from time import time
 | 
			
		||||
from os.path import abspath, dirname, isdir
 | 
			
		||||
 | 
			
		||||
from watchdog.observers import Observer
 | 
			
		||||
from watchdog.events import FileSystemMovedEvent, PatternMatchingEventHandler
 | 
			
		||||
from watchdog.events import (
 | 
			
		||||
    FileCreatedEvent, FileModifiedEvent, FileMovedEvent, FileDeletedEvent,
 | 
			
		||||
    DirCreatedEvent, DirModifiedEvent, DirMovedEvent, DirDeletedEvent
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyEvent:
 | 
			
		||||
    def __init__(self, src_path):
 | 
			
		||||
        self.date = time()
 | 
			
		||||
        self.src_path = src_path
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return self.__repr__()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f'{self.__class__.__name__}({self.src_path})'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyMovedEvent(InotifyEvent):
 | 
			
		||||
    def __init__(self, src_path, dest_path):
 | 
			
		||||
        self.dest_path = dest_path
 | 
			
		||||
        super().__init__(src_path)
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return f'{self.__class__.__name__}({self.src_path}, {self.dest_path})'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyFileCreatedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyFileModifiedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyFileMovedEvent(InotifyMovedEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyFileDeletedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyDirCreatedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyDirModifiedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyDirMovedEvent(InotifyMovedEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyDirDeletedEvent(InotifyEvent):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyObserver:
 | 
			
		||||
    def __init__(self, path, patterns=None, exclude=None, recursive=False):
 | 
			
		||||
        self._files = []
 | 
			
		||||
        self._paths = path
 | 
			
		||||
        self._patterns = patterns or []
 | 
			
		||||
        self._exclude = exclude
 | 
			
		||||
        self._recursive = recursive
 | 
			
		||||
        self._observer = Observer()
 | 
			
		||||
        self._reset()
 | 
			
		||||
 | 
			
		||||
    def _reset(self):
 | 
			
		||||
        if isinstance(self._paths, str):
 | 
			
		||||
            self._paths = [self._paths]
 | 
			
		||||
        if isinstance(self._paths, Iterable):
 | 
			
		||||
            self._extract_files_from_paths()
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError('Expected one or more string paths.')
 | 
			
		||||
 | 
			
		||||
        patterns = self._files+self.patterns
 | 
			
		||||
        handler = PatternMatchingEventHandler(patterns if patterns else None, self.exclude)
 | 
			
		||||
        handler.on_any_event = self._dispatch_event
 | 
			
		||||
        self._observer.unschedule_all()
 | 
			
		||||
        for path in self.paths:
 | 
			
		||||
            self._observer.schedule(handler, path, self._recursive)
 | 
			
		||||
 | 
			
		||||
    def _extract_files_from_paths(self):
 | 
			
		||||
        files, paths = [], []
 | 
			
		||||
        for path in self._paths:
 | 
			
		||||
            path = abspath(path)
 | 
			
		||||
            if isdir(path):
 | 
			
		||||
                paths.append(path)
 | 
			
		||||
            else:
 | 
			
		||||
                paths.append(dirname(path))
 | 
			
		||||
                files.append(path)
 | 
			
		||||
        self._files, self._paths = files, paths
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def paths(self):
 | 
			
		||||
        return self._paths
 | 
			
		||||
 | 
			
		||||
    @paths.setter
 | 
			
		||||
    def paths(self, paths):
 | 
			
		||||
        self._paths = paths
 | 
			
		||||
        self._reset()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def patterns(self):
 | 
			
		||||
        return self._patterns
 | 
			
		||||
 | 
			
		||||
    @patterns.setter
 | 
			
		||||
    def patterns(self, patterns):
 | 
			
		||||
        self._patterns = patterns or []
 | 
			
		||||
        self._reset()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def exclude(self):
 | 
			
		||||
        return self._exclude
 | 
			
		||||
 | 
			
		||||
    @exclude.setter
 | 
			
		||||
    def exclude(self, exclude):
 | 
			
		||||
        self._exclude = exclude
 | 
			
		||||
        self._reset()
 | 
			
		||||
 | 
			
		||||
    def start(self):
 | 
			
		||||
        self._observer.start()
 | 
			
		||||
 | 
			
		||||
    def stop(self):
 | 
			
		||||
        self._observer.stop()
 | 
			
		||||
        self._observer.join()
 | 
			
		||||
 | 
			
		||||
    def _dispatch_event(self, event):
 | 
			
		||||
        event_to_action = {
 | 
			
		||||
            InotifyFileCreatedEvent  : self.on_created,
 | 
			
		||||
            InotifyFileModifiedEvent : self.on_modified,
 | 
			
		||||
            InotifyFileMovedEvent    : self.on_moved,
 | 
			
		||||
            InotifyFileDeletedEvent  : self.on_deleted,
 | 
			
		||||
            InotifyDirCreatedEvent   : self.on_created,
 | 
			
		||||
            InotifyDirModifiedEvent  : self.on_modified,
 | 
			
		||||
            InotifyDirMovedEvent     : self.on_moved,
 | 
			
		||||
            InotifyDirDeletedEvent   : self.on_deleted
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        event = self._transform_event(event)
 | 
			
		||||
        self.on_any_event(event)
 | 
			
		||||
        event_to_action[type(event)](event)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def _transform_event(event):
 | 
			
		||||
        watchdog_to_inotify = {
 | 
			
		||||
            FileCreatedEvent  : InotifyFileCreatedEvent,
 | 
			
		||||
            FileModifiedEvent : InotifyFileModifiedEvent,
 | 
			
		||||
            FileMovedEvent    : InotifyFileMovedEvent,
 | 
			
		||||
            FileDeletedEvent  : InotifyFileDeletedEvent,
 | 
			
		||||
            DirCreatedEvent   : InotifyDirCreatedEvent,
 | 
			
		||||
            DirModifiedEvent  : InotifyDirModifiedEvent,
 | 
			
		||||
            DirMovedEvent     : InotifyDirMovedEvent,
 | 
			
		||||
            DirDeletedEvent   : InotifyDirDeletedEvent
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            cls = watchdog_to_inotify[type(event)]
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            raise NameError('Watchdog API returned an unknown event.')
 | 
			
		||||
 | 
			
		||||
        if isinstance(event, FileSystemMovedEvent):
 | 
			
		||||
            return cls(event.src_path, event.dest_path)
 | 
			
		||||
        return cls(event.src_path)
 | 
			
		||||
 | 
			
		||||
    def on_any_event(self, event):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def on_created(self, event):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def on_modified(self, event):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def on_moved(self, event):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    def on_deleted(self, event):
 | 
			
		||||
        pass
 | 
			
		||||
							
								
								
									
										179
									
								
								tfw/internals/inotify/test_inotify.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										179
									
								
								tfw/internals/inotify/test_inotify.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,179 @@
 | 
			
		||||
# pylint: disable=redefined-outer-name
 | 
			
		||||
 | 
			
		||||
from queue import Empty, Queue
 | 
			
		||||
from secrets import token_urlsafe
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from shutil import rmtree
 | 
			
		||||
from os.path import join
 | 
			
		||||
from os import mkdir, remove, rename
 | 
			
		||||
from tempfile import TemporaryDirectory
 | 
			
		||||
from contextlib import suppress
 | 
			
		||||
 | 
			
		||||
import watchdog
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from .inotify import InotifyObserver
 | 
			
		||||
from .inotify import (
 | 
			
		||||
    InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
 | 
			
		||||
    InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
 | 
			
		||||
    InotifyDirMovedEvent, InotifyDirDeletedEvent
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
with suppress(AttributeError):
 | 
			
		||||
    watchdog.observers.inotify_buffer.InotifyBuffer.delay = 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyContext:
 | 
			
		||||
    def __init__(self, workdir, subdir, subfile, observer):
 | 
			
		||||
        self.missing_events = 0
 | 
			
		||||
        self.workdir = workdir
 | 
			
		||||
        self.subdir = subdir
 | 
			
		||||
        self.subfile = subfile
 | 
			
		||||
        self.observer = observer
 | 
			
		||||
 | 
			
		||||
        self.event_to_queue = {
 | 
			
		||||
            InotifyFileCreatedEvent  : self.observer.create_queue,
 | 
			
		||||
            InotifyFileModifiedEvent : self.observer.modify_queue,
 | 
			
		||||
            InotifyFileMovedEvent    : self.observer.move_queue,
 | 
			
		||||
            InotifyFileDeletedEvent  : self.observer.delete_queue,
 | 
			
		||||
            InotifyDirCreatedEvent   : self.observer.create_queue,
 | 
			
		||||
            InotifyDirModifiedEvent  : self.observer.modify_queue,
 | 
			
		||||
            InotifyDirMovedEvent     : self.observer.move_queue,
 | 
			
		||||
            InotifyDirDeletedEvent   : self.observer.delete_queue
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    def create_random_file(self, dirname, extension):
 | 
			
		||||
        filename = self.join(f'{dirname}/{generate_name()}{extension}')
 | 
			
		||||
        Path(filename).touch()
 | 
			
		||||
        return filename
 | 
			
		||||
 | 
			
		||||
    def create_random_folder(self, basepath):
 | 
			
		||||
        dirname = self.join(f'{basepath}/{generate_name()}')
 | 
			
		||||
        mkdir(dirname)
 | 
			
		||||
        return dirname
 | 
			
		||||
 | 
			
		||||
    def join(self, path):
 | 
			
		||||
        return join(self.workdir, path)
 | 
			
		||||
 | 
			
		||||
    def check_event(self, event_type, path):
 | 
			
		||||
        self.missing_events += 1
 | 
			
		||||
        event = self.event_to_queue[event_type].get(timeout=0.1)
 | 
			
		||||
        assert isinstance(event, event_type)
 | 
			
		||||
        assert event.src_path == path
 | 
			
		||||
        return event
 | 
			
		||||
 | 
			
		||||
    def check_empty(self, event_type):
 | 
			
		||||
        with pytest.raises(Empty):
 | 
			
		||||
            self.event_to_queue[event_type].get(timeout=0.1)
 | 
			
		||||
 | 
			
		||||
    def check_any(self):
 | 
			
		||||
        attrs = self.observer.__dict__.values()
 | 
			
		||||
        total = sum([q.qsize() for q in attrs if isinstance(q, Queue)])
 | 
			
		||||
        return total+self.missing_events == len(self.observer.any_list)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InotifyTestObserver(InotifyObserver):
 | 
			
		||||
    def __init__(self, paths, patterns=None, exclude=None, recursive=False):
 | 
			
		||||
        self.any_list = []
 | 
			
		||||
        self.create_queue, self.modify_queue, self.move_queue, self.delete_queue = [Queue() for _ in range(4)]
 | 
			
		||||
        super().__init__(paths, patterns, exclude, recursive)
 | 
			
		||||
 | 
			
		||||
    def on_any_event(self, event):
 | 
			
		||||
        self.any_list.append(event)
 | 
			
		||||
 | 
			
		||||
    def on_created(self, event):
 | 
			
		||||
        self.create_queue.put(event)
 | 
			
		||||
 | 
			
		||||
    def on_modified(self, event):
 | 
			
		||||
        self.modify_queue.put(event)
 | 
			
		||||
 | 
			
		||||
    def on_moved(self, event):
 | 
			
		||||
        self.move_queue.put(event)
 | 
			
		||||
 | 
			
		||||
    def on_deleted(self, event):
 | 
			
		||||
        self.delete_queue.put(event)
 | 
			
		||||
 | 
			
		||||
def generate_name():
 | 
			
		||||
    return token_urlsafe(16)
 | 
			
		||||
 | 
			
		||||
@pytest.fixture()
 | 
			
		||||
def context():
 | 
			
		||||
    with TemporaryDirectory() as workdir:
 | 
			
		||||
        subdir = join(workdir, generate_name())
 | 
			
		||||
        subfile = join(subdir, generate_name()+'.txt')
 | 
			
		||||
        mkdir(subdir)
 | 
			
		||||
        Path(subfile).touch()
 | 
			
		||||
        monitor = InotifyTestObserver(workdir, recursive=True)
 | 
			
		||||
        monitor.start()
 | 
			
		||||
        yield InotifyContext(workdir, subdir, subfile, monitor)
 | 
			
		||||
 | 
			
		||||
def test_create(context):
 | 
			
		||||
    newfile = context.create_random_file(context.workdir, '.txt')
 | 
			
		||||
    context.check_event(InotifyFileCreatedEvent, newfile)
 | 
			
		||||
    newdir = context.create_random_folder(context.workdir)
 | 
			
		||||
    context.check_event(InotifyDirCreatedEvent, newdir)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
 | 
			
		||||
def test_modify(context):
 | 
			
		||||
    with open(context.subfile, 'wb', buffering=0) as ofile:
 | 
			
		||||
        ofile.write(b'text')
 | 
			
		||||
    context.check_event(InotifyFileModifiedEvent, context.subfile)
 | 
			
		||||
    while True:
 | 
			
		||||
        try:
 | 
			
		||||
            context.observer.modify_queue.get(timeout=0.1)
 | 
			
		||||
            context.missing_events += 1
 | 
			
		||||
        except Empty:
 | 
			
		||||
            break
 | 
			
		||||
    rename(context.subfile, context.subfile+'_new')
 | 
			
		||||
    context.check_event(InotifyDirModifiedEvent, context.subdir)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
 | 
			
		||||
def test_move(context):
 | 
			
		||||
    rename(context.subdir, context.subdir+'_new')
 | 
			
		||||
    context.check_event(InotifyDirMovedEvent, context.subdir)
 | 
			
		||||
    context.check_event(InotifyFileMovedEvent, context.subfile)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
 | 
			
		||||
def test_delete(context):
 | 
			
		||||
    rmtree(context.subdir)
 | 
			
		||||
    context.check_event(InotifyFileDeletedEvent, context.subfile)
 | 
			
		||||
    context.check_event(InotifyDirDeletedEvent, context.subdir)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
 | 
			
		||||
def test_paths(context):
 | 
			
		||||
    context.observer.paths = context.subdir
 | 
			
		||||
    newdir = context.create_random_folder(context.workdir)
 | 
			
		||||
    newfile = context.create_random_file(context.subdir, '.txt')
 | 
			
		||||
    context.check_event(InotifyDirModifiedEvent, context.subdir)
 | 
			
		||||
    context.check_event(InotifyFileCreatedEvent, newfile)
 | 
			
		||||
    context.observer.paths = [newdir, newfile]
 | 
			
		||||
    remove(newfile)
 | 
			
		||||
    context.check_event(InotifyFileDeletedEvent, newfile)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
    context.observer.paths = context.workdir
 | 
			
		||||
 | 
			
		||||
def test_patterns(context):
 | 
			
		||||
    context.observer.patterns = ['*.txt']
 | 
			
		||||
    context.create_random_file(context.subdir, '.bin')
 | 
			
		||||
    newfile = context.create_random_file(context.subdir, '.txt')
 | 
			
		||||
    context.check_event(InotifyFileCreatedEvent, newfile)
 | 
			
		||||
    context.check_empty(InotifyFileCreatedEvent)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
    context.observer.patterns = None
 | 
			
		||||
 | 
			
		||||
def test_exclude(context):
 | 
			
		||||
    context.observer.exclude = ['*.txt']
 | 
			
		||||
    context.create_random_file(context.subdir, '.txt')
 | 
			
		||||
    newfile = context.create_random_file(context.subdir, '.bin')
 | 
			
		||||
    context.check_event(InotifyFileCreatedEvent, newfile)
 | 
			
		||||
    context.check_empty(InotifyFileCreatedEvent)
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
    context.observer.exclude = None
 | 
			
		||||
 | 
			
		||||
def test_stress(context):
 | 
			
		||||
    newfile = []
 | 
			
		||||
    for i in range(1024):
 | 
			
		||||
        newfile.append(context.create_random_file(context.subdir, '.txt'))
 | 
			
		||||
    for i in range(1024):
 | 
			
		||||
        context.check_event(InotifyFileCreatedEvent, newfile[i])
 | 
			
		||||
    assert context.check_any()
 | 
			
		||||
							
								
								
									
										27
									
								
								tfw/internals/lazy.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								tfw/internals/lazy.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,27 @@
 | 
			
		||||
from functools import update_wrapper, wraps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class lazy_property:
 | 
			
		||||
    """
 | 
			
		||||
    Decorator that replaces a function with the value
 | 
			
		||||
    it calculates on the first call.
 | 
			
		||||
    """
 | 
			
		||||
    def __init__(self, func):
 | 
			
		||||
        self.func = func
 | 
			
		||||
        update_wrapper(self, func)
 | 
			
		||||
 | 
			
		||||
    def __get__(self, instance, owner):
 | 
			
		||||
        if instance is None:
 | 
			
		||||
            return self  # avoids potential __new__ TypeError
 | 
			
		||||
        value = self.func(instance)
 | 
			
		||||
        setattr(instance, self.func.__name__, value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def lazy_factory(fun):
 | 
			
		||||
    class wrapper:
 | 
			
		||||
        @wraps(fun)
 | 
			
		||||
        @lazy_property
 | 
			
		||||
        def instance(self):  # pylint: disable=no-self-use
 | 
			
		||||
            return fun()
 | 
			
		||||
    return wrapper()
 | 
			
		||||
							
								
								
									
										4
									
								
								tfw/internals/networking/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								tfw/internals/networking/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,4 @@
 | 
			
		||||
from .serialization import serialize_tfw_msg, deserialize_tfw_msg, with_deserialize_tfw_msg, message_bytes
 | 
			
		||||
from .server_connector import ServerUplinkConnector, ServerDownlinkConnector, ServerConnector
 | 
			
		||||
from .event_handler_connector import EventHandlerConnector
 | 
			
		||||
from .scope import Scope
 | 
			
		||||
							
								
								
									
										48
									
								
								tfw/internals/networking/event_handler_connector.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								tfw/internals/networking/event_handler_connector.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
import zmq
 | 
			
		||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
			
		||||
 | 
			
		||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventHandlerDownlinkConnector:
 | 
			
		||||
    def __init__(self, bind_addr):
 | 
			
		||||
        self._zmq_pull_socket = zmq.Context.instance().socket(zmq.PULL)
 | 
			
		||||
        self._zmq_pull_socket.setsockopt(zmq.RCVHWM, 0)
 | 
			
		||||
        self._zmq_pull_stream = ZMQStream(self._zmq_pull_socket)
 | 
			
		||||
        self._zmq_pull_socket.bind(bind_addr)
 | 
			
		||||
        LOG.debug('Pull socket bound to %s', bind_addr)
 | 
			
		||||
 | 
			
		||||
    def register_callback(self, callback):
 | 
			
		||||
        callback = with_deserialize_tfw_msg(callback)
 | 
			
		||||
        self._zmq_pull_stream.on_recv(callback)
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._zmq_pull_stream.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventHandlerUplinkConnector:
 | 
			
		||||
    def __init__(self, bind_addr):
 | 
			
		||||
        self._zmq_pub_socket = zmq.Context.instance().socket(zmq.PUB)
 | 
			
		||||
        self._zmq_pub_socket.setsockopt(zmq.SNDHWM, 0)
 | 
			
		||||
        self._zmq_pub_socket.bind(bind_addr)
 | 
			
		||||
        LOG.debug('Pub socket bound to %s', bind_addr)
 | 
			
		||||
 | 
			
		||||
    def send_message(self, message: dict):
 | 
			
		||||
        self._zmq_pub_socket.send_multipart(serialize_tfw_msg(message))
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._zmq_pub_socket.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventHandlerConnector(EventHandlerDownlinkConnector, EventHandlerUplinkConnector):
 | 
			
		||||
    def __init__(self, downlink_bind_addr, uplink_bind_addr):
 | 
			
		||||
        EventHandlerDownlinkConnector.__init__(self, downlink_bind_addr)
 | 
			
		||||
        EventHandlerUplinkConnector.__init__(self, uplink_bind_addr)
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        EventHandlerDownlinkConnector.close(self)
 | 
			
		||||
        EventHandlerUplinkConnector.close(self)
 | 
			
		||||
							
								
								
									
										7
									
								
								tfw/internals/networking/scope.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								tfw/internals/networking/scope.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
			
		||||
from enum import Enum
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Scope(Enum):
 | 
			
		||||
    ZMQ = 'zmq'
 | 
			
		||||
    WEBSOCKET = 'websocket'
 | 
			
		||||
    BROADCAST = 'broadcast'
 | 
			
		||||
							
								
								
									
										104
									
								
								tfw/internals/networking/serialization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								tfw/internals/networking/serialization.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,104 @@
 | 
			
		||||
"""
 | 
			
		||||
TFW JSON message format
 | 
			
		||||
 | 
			
		||||
message:
 | 
			
		||||
{
 | 
			
		||||
    "key":     string,    # addressing
 | 
			
		||||
    "data":    {...},     # payload
 | 
			
		||||
    "trigger": string     # FSM trigger
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ZeroMQ's sub-pub sockets use enveloped messages
 | 
			
		||||
(http://zguide.zeromq.org/page:all#Pub-Sub-Message-Envelopes)
 | 
			
		||||
and TFW also uses them internally. This means that on ZMQ sockets
 | 
			
		||||
we always send the messages key separately and then the actual
 | 
			
		||||
message (which contains the key as well) like so:
 | 
			
		||||
 | 
			
		||||
socket.send_multipart([message['key'], message])
 | 
			
		||||
 | 
			
		||||
The purpose of this module is abstracting away this low level behaviour.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
from functools import wraps
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize_tfw_msg(message):
 | 
			
		||||
    """
 | 
			
		||||
    Create TFW multipart data from message dict
 | 
			
		||||
    """
 | 
			
		||||
    return _serialize_all(message['key'], message)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def with_deserialize_tfw_msg(fun):
 | 
			
		||||
    @wraps(fun)
 | 
			
		||||
    def wrapper(message_parts):
 | 
			
		||||
        message = deserialize_tfw_msg(*message_parts)
 | 
			
		||||
        return fun(message)
 | 
			
		||||
    return wrapper
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def deserialize_tfw_msg(*args):
 | 
			
		||||
    """
 | 
			
		||||
    Return message from TFW multipart data
 | 
			
		||||
    """
 | 
			
		||||
    return _deserialize_all(*args)[1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _serialize_all(*args):
 | 
			
		||||
    return tuple(
 | 
			
		||||
        _serialize_single(arg)
 | 
			
		||||
        for arg in args
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _deserialize_all(*args):
 | 
			
		||||
    return tuple(
 | 
			
		||||
        _deserialize_single(arg)
 | 
			
		||||
        for arg in args
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _serialize_single(data):
 | 
			
		||||
    """
 | 
			
		||||
    Return input as bytes
 | 
			
		||||
    (serialize input if it is JSON)
 | 
			
		||||
    """
 | 
			
		||||
    if not isinstance(data, str):
 | 
			
		||||
        data = message_bytes(data)
 | 
			
		||||
    return _encode_if_needed(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_bytes(message):
 | 
			
		||||
    return json.dumps(message, sort_keys=True).encode()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _deserialize_single(data):
 | 
			
		||||
    """
 | 
			
		||||
    Try parsing input as JSON, return it as
 | 
			
		||||
    string if parsing fails.
 | 
			
		||||
    """
 | 
			
		||||
    try:
 | 
			
		||||
        return json.loads(data)
 | 
			
		||||
    except ValueError:
 | 
			
		||||
        return _decode_if_needed(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _encode_if_needed(value):
 | 
			
		||||
    """
 | 
			
		||||
    Return input as bytes
 | 
			
		||||
    (encode if input is string)
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(value, str):
 | 
			
		||||
        value = value.encode('utf-8')
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _decode_if_needed(value):
 | 
			
		||||
    """
 | 
			
		||||
    Return input as string
 | 
			
		||||
    (decode if input is bytes)
 | 
			
		||||
    """
 | 
			
		||||
    if isinstance(value, (bytes, bytearray)):
 | 
			
		||||
        value = value.decode('utf-8')
 | 
			
		||||
    return value
 | 
			
		||||
							
								
								
									
										65
									
								
								tfw/internals/networking/server_connector.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										65
									
								
								tfw/internals/networking/server_connector.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,65 @@
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
import zmq
 | 
			
		||||
from zmq.eventloop.zmqstream import ZMQStream
 | 
			
		||||
 | 
			
		||||
from .scope import Scope
 | 
			
		||||
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServerDownlinkConnector:
 | 
			
		||||
    def __init__(self, connect_addr):
 | 
			
		||||
        self.keys = []
 | 
			
		||||
        self._on_recv_callback = None
 | 
			
		||||
        self._zmq_sub_socket = zmq.Context.instance().socket(zmq.SUB)
 | 
			
		||||
        self._zmq_sub_socket.setsockopt(zmq.RCVHWM, 0)
 | 
			
		||||
        self._zmq_sub_socket.connect(connect_addr)
 | 
			
		||||
        self._zmq_sub_stream = ZMQStream(self._zmq_sub_socket)
 | 
			
		||||
 | 
			
		||||
    def subscribe(self, *keys):
 | 
			
		||||
        for key in keys:
 | 
			
		||||
            self._zmq_sub_socket.setsockopt_string(zmq.SUBSCRIBE, key)
 | 
			
		||||
            self.keys.append(key)
 | 
			
		||||
 | 
			
		||||
    def unsubscribe(self, *keys):
 | 
			
		||||
        for key in keys:
 | 
			
		||||
            self._zmq_sub_socket.setsockopt_string(zmq.UNSUBSCRIBE, key)
 | 
			
		||||
            self.keys.remove(key)
 | 
			
		||||
 | 
			
		||||
    def register_callback(self, callback):
 | 
			
		||||
        self._on_recv_callback = callback
 | 
			
		||||
        self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
 | 
			
		||||
 | 
			
		||||
    def _on_recv(self, message):
 | 
			
		||||
        key = message['key']
 | 
			
		||||
        if key in self.keys or '' in self.keys:
 | 
			
		||||
            self._on_recv_callback(message)
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._zmq_sub_stream.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServerUplinkConnector:
 | 
			
		||||
    def __init__(self, connect_addr):
 | 
			
		||||
        self._zmq_push_socket = zmq.Context.instance().socket(zmq.PUSH)
 | 
			
		||||
        self._zmq_push_socket.setsockopt(zmq.SNDHWM, 0)
 | 
			
		||||
        self._zmq_push_socket.connect(connect_addr)
 | 
			
		||||
 | 
			
		||||
    def send_message(self, message, scope=Scope.ZMQ):
 | 
			
		||||
        message['scope'] = scope.value
 | 
			
		||||
        self._zmq_push_socket.send_multipart(serialize_tfw_msg(message))
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        self._zmq_push_socket.close()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ServerConnector(ServerDownlinkConnector, ServerUplinkConnector):
 | 
			
		||||
    def __init__(self, downlink_connect_addr, uplink_connect_addr):
 | 
			
		||||
        ServerDownlinkConnector.__init__(self, downlink_connect_addr)
 | 
			
		||||
        ServerUplinkConnector.__init__(self, uplink_connect_addr)
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        ServerDownlinkConnector.close(self)
 | 
			
		||||
        ServerUplinkConnector.close(self)
 | 
			
		||||
							
								
								
									
										1
									
								
								tfw/internals/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								tfw/internals/server/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
from .zmq_websocket_router import ZMQWebSocketRouter
 | 
			
		||||
							
								
								
									
										69
									
								
								tfw/internals/server/zmq_websocket_router.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								tfw/internals/server/zmq_websocket_router.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
from tornado.websocket import WebSocketHandler
 | 
			
		||||
 | 
			
		||||
from tfw.internals.networking import Scope
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ZMQWebSocketRouter(WebSocketHandler):
 | 
			
		||||
    # pylint: disable=abstract-method,attribute-defined-outside-init
 | 
			
		||||
    instances = set()
 | 
			
		||||
 | 
			
		||||
    def initialize(self, **kwargs):
 | 
			
		||||
        self.event_handler_connector = kwargs['event_handler_connector']
 | 
			
		||||
        self.tfw_router = TFWRouter(self.send_to_zmq, self.send_to_websockets)
 | 
			
		||||
 | 
			
		||||
    def send_to_zmq(self, message):
 | 
			
		||||
        self.event_handler_connector.send_message(message)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def send_to_websockets(cls, message):
 | 
			
		||||
        for instance in cls.instances:
 | 
			
		||||
            instance.write_message(message)
 | 
			
		||||
 | 
			
		||||
    def prepare(self):
 | 
			
		||||
        type(self).instances.add(self)
 | 
			
		||||
 | 
			
		||||
    def on_close(self):
 | 
			
		||||
        type(self).instances.remove(self)
 | 
			
		||||
 | 
			
		||||
    def open(self, *args, **kwargs):
 | 
			
		||||
        LOG.debug('WebSocket connection initiated!')
 | 
			
		||||
        self.event_handler_connector.register_callback(self.zmq_callback)
 | 
			
		||||
 | 
			
		||||
    def zmq_callback(self, message):
 | 
			
		||||
        LOG.debug('Received on ZMQ pull socket: %s', message)
 | 
			
		||||
        self.tfw_router.route(message)
 | 
			
		||||
 | 
			
		||||
    def on_message(self, message):
 | 
			
		||||
        message = json.loads(message)
 | 
			
		||||
        LOG.debug('Received on WebSocket: %s', message)
 | 
			
		||||
        self.tfw_router.route(message)
 | 
			
		||||
 | 
			
		||||
    # much secure, very cors, wow
 | 
			
		||||
    def check_origin(self, origin):
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TFWRouter:
 | 
			
		||||
    def __init__(self, send_to_zmq, send_to_websockets):
 | 
			
		||||
        self.send_to_zmq = send_to_zmq
 | 
			
		||||
        self.send_to_websockets = send_to_websockets
 | 
			
		||||
 | 
			
		||||
    def route(self, message):
 | 
			
		||||
        scope = Scope(message.pop('scope', 'zmq'))
 | 
			
		||||
 | 
			
		||||
        routing_table = {
 | 
			
		||||
            Scope.ZMQ: self.send_to_zmq,
 | 
			
		||||
            Scope.WEBSOCKET: self.send_to_websockets,
 | 
			
		||||
            Scope.BROADCAST: self.broadcast
 | 
			
		||||
        }
 | 
			
		||||
        action = routing_table[scope]
 | 
			
		||||
        action(message)
 | 
			
		||||
 | 
			
		||||
    def broadcast(self, message):
 | 
			
		||||
        self.send_to_zmq(message)
 | 
			
		||||
        self.send_to_websockets(message)
 | 
			
		||||
		Reference in New Issue
	
	Block a user