Simplify package structure

This commit is contained in:
Kristóf Tóth
2019-07-24 15:50:41 +02:00
parent a23224aced
commit 52399f413c
79 changed files with 22 additions and 24 deletions

View File

View File

@ -0,0 +1,35 @@
from functools import partial
from .lazy import lazy_property
class CallbackMixin:
@lazy_property
def _callbacks(self):
# pylint: disable=no-self-use
return []
def subscribe_callback(self, callback, *args, **kwargs):
"""
Subscribe a callable to invoke once an event is triggered.
:param callback: callable to be executed on events
:param args: arguments passed to callable
:param kwargs: kwargs passed to callable
"""
fun = partial(callback, *args, **kwargs)
self._callbacks.append(fun)
def subscribe_callbacks(self, *callbacks):
"""
Subscribe a list of callbacks to incoke once an event is triggered.
:param callbacks: callbacks to be subscribed
"""
for callback in callbacks:
self.subscribe_callback(callback)
def unsubscribe_callback(self, callback):
self._callbacks.remove(callback)
def _execute_callbacks(self, *args, **kwargs):
for callback in self._callbacks:
callback(*args, **kwargs)

107
tfw/internals/crypto.py Normal file
View File

@ -0,0 +1,107 @@
from functools import wraps
from base64 import b64encode, b64decode
from copy import deepcopy
from hashlib import md5
from os import urandom, chmod
from os.path import exists
from stat import S_IRUSR, S_IWUSR, S_IXUSR
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
from cryptography.exceptions import InvalidSignature
from tfw.internals.networking import message_bytes
from tfw.internals.lazy import lazy_property
from tfw.config import TFWENV
def message_checksum(message):
return md5(message_bytes(message)).hexdigest()
def sign_message(key, message):
message.pop('scope', None)
message.pop('signature', None)
signature = message_signature(key, message)
message['signature'] = b64encode(signature).decode()
def message_signature(key, message):
return HMAC(key, message_bytes(message)).signature
def verify_message(key, message):
message.pop('scope', None)
message = deepcopy(message)
try:
signature_b64 = message.pop('signature')
signature = b64decode(signature_b64)
actual_signature = message_signature(key, message)
return signature == actual_signature
except KeyError:
return False
class KeyManager:
def __init__(self):
self.keyfile = TFWENV.AUTH_KEY
if not exists(self.keyfile):
self._init_auth_key()
@lazy_property
def auth_key(self):
with open(self.keyfile, 'rb') as ifile:
return ifile.read()
def _init_auth_key(self):
key = self.generate_key()
with open(self.keyfile, 'wb') as ofile:
ofile.write(key)
self._chmod_700_keyfile()
return key
@staticmethod
def generate_key():
return urandom(32)
def _chmod_700_keyfile(self):
chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
class HMAC:
def __init__(self, key, message):
self.key = key
self.message = message
self._hmac = _HMAC(
key=key,
algorithm=SHA256(),
backend=default_backend()
)
def _reload_if_finalized(f):
# pylint: disable=no-self-argument,not-callable
@wraps(f)
def wrapped(instance, *args, **kwargs):
if getattr(instance, '_finalized', False):
instance.__init__(instance.key, instance.message)
ret_val = f(instance, *args, **kwargs)
setattr(instance, '_finalized', True)
return ret_val
return wrapped
@property
@_reload_if_finalized
def signature(self):
self._hmac.update(self.message)
signature = self._hmac.finalize()
return signature
@_reload_if_finalized
def verify(self, signature):
self._hmac.update(self.message)
try:
self._hmac.verify(signature)
return True
except InvalidSignature:
return False

View File

@ -0,0 +1,3 @@
from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler
from .fsm_aware_event_handler import FSMAwareEventHandler

View File

@ -0,0 +1,27 @@
class EventHandler:
_instances = set()
def __init__(self, server_connector):
type(self)._instances.add(self)
self.server_connector = server_connector
def start(self):
self.server_connector.register_callback(self._event_callback)
def _event_callback(self, message):
self.handle_event(message, self.server_connector)
def handle_event(self, message, server_connector):
raise NotImplementedError()
@classmethod
def stop_all_instances(cls):
for instance in cls._instances:
instance.stop()
def stop(self):
self.server_connector.close()
self.cleanup()
def cleanup(self):
pass

View File

@ -0,0 +1,68 @@
from contextlib import suppress
from .event_handler import EventHandler
class EventHandlerFactoryBase:
def build(self, handler_stub, *, keys=None, event_handler_type=EventHandler):
builder = EventHandlerBuilder(handler_stub, keys, event_handler_type)
server_connector = self._build_server_connector()
event_handler = builder.build(server_connector)
handler_stub.server_connector = server_connector
with suppress(AttributeError):
handler_stub.start()
event_handler.start()
return event_handler
def _build_server_connector(self):
raise NotImplementedError()
class EventHandlerBuilder:
def __init__(self, event_handler, supplied_keys, event_handler_type):
self._analyzer = HandlerStubAnalyzer(event_handler, supplied_keys)
self._event_handler_type = event_handler_type
def build(self, server_connector):
event_handler = self._event_handler_type(server_connector)
server_connector.subscribe(*self._try_get_keys(event_handler))
event_handler.handle_event = self._analyzer.handle_event
with suppress(AttributeError):
event_handler.cleanup = self._analyzer.cleanup
return event_handler
def _try_get_keys(self, event_handler):
try:
return self._analyzer.keys
except ValueError:
with suppress(AttributeError):
return event_handler.keys
raise
class HandlerStubAnalyzer:
def __init__(self, event_handler, supplied_keys):
self._event_handler = event_handler
self._supplied_keys = supplied_keys
@property
def keys(self):
if self._supplied_keys is None:
try:
return self._event_handler.keys
except AttributeError:
raise ValueError('No keys supplied!')
return self._supplied_keys
@property
def handle_event(self):
try:
return self._event_handler.handle_event
except AttributeError:
if callable(self._event_handler):
return self._event_handler
raise ValueError('Object must implement handle_event or be a callable!')
@property
def cleanup(self):
return self._event_handler.cleanup

View File

@ -0,0 +1,37 @@
import logging
from tfw.internals.crypto import KeyManager, verify_message
LOG = logging.getLogger(__name__)
class FSMAware:
keys = ['fsm_update']
"""
Base class for stuff that has to be aware of the framework FSM.
This is done by processing 'fsm_update' messages.
"""
def __init__(self):
self.fsm_state = None
self.fsm_in_accepted_state = False
self.fsm_event_log = []
self._auth_key = KeyManager().auth_key
def process_message(self, message):
if message['key'] == 'fsm_update':
if verify_message(self._auth_key, message):
self._handle_fsm_update(message)
def _handle_fsm_update(self, message):
try:
new_state = message['current_state']
if self.fsm_state != new_state:
self.handle_fsm_step(message)
self.fsm_state = new_state
self.fsm_in_accepted_state = message['in_accepted_state']
self.fsm_event_log.append(message)
except KeyError:
LOG.error('Invalid fsm_update message received!')
def handle_fsm_step(self, message):
pass

View File

@ -0,0 +1,19 @@
from .event_handler import EventHandler
from .fsm_aware import FSMAware
class FSMAwareEventHandler(EventHandler, FSMAware):
# pylint: disable=abstract-method
"""
Abstract base class for EventHandlers which automatically
keep track of the state of the TFW FSM.
"""
def __init__(self, server_connector):
EventHandler.__init__(self, server_connector)
FSMAware.__init__(self)
def _event_callback(self, message):
self.process_message(message)
def handle_fsm_step(self, message):
self.handle_event(message, self.server_connector)

View File

@ -0,0 +1,190 @@
# pylint: disable=redefined-outer-name,attribute-defined-outside-init
from secrets import token_urlsafe
from random import randint
import pytest
from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler
class MockEventHandlerFactory(EventHandlerFactoryBase):
def _build_server_connector(self):
return MockServerConnector()
class MockServerConnector:
def __init__(self):
self.keys = []
self._on_message = None
def simulate_message(self, message):
self._on_message(message)
def register_callback(self, callback):
self._on_message = callback
def subscribe(self, *keys):
self.keys.extend(keys)
def unsubscribe(self, *keys):
for key in keys:
self.keys.remove(key)
def send_message(self, message, scope=None):
pass
def close(self):
pass
class MockEventHandlerStub:
def __init__(self):
self.server_connector = None
self.last_message = None
self.cleaned_up = False
self.started = False
def start(self):
self.started = True
def cleanup(self):
self.cleaned_up = True
class MockEventHandler(MockEventHandlerStub):
# pylint: disable=unused-argument
def handle_event(self, message, server_connector):
self.last_message = message
class MockCallable(MockEventHandlerStub):
def __call__(self, message, server_connector):
self.last_message = message
@pytest.fixture
def test_msg():
yield token_urlsafe(randint(16, 64))
@pytest.fixture
def test_keys():
yield [
token_urlsafe(randint(2, 8))
for _ in range(randint(16, 32))
]
def test_build_from_object(test_keys, test_msg):
mock_eh = MockEventHandlerStub()
def handle_event(message, server_connector):
raise RuntimeError(message, server_connector.keys)
mock_eh.handle_event = handle_event
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
assert mock_eh.started
assert mock_eh.server_connector is eh.server_connector
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
assert not mock_eh.cleaned_up
eh.stop()
assert mock_eh.cleaned_up
def test_build_from_object_with_keys(test_keys, test_msg):
mock_eh = MockEventHandler()
mock_eh.keys = test_keys
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh)
assert mock_eh.server_connector.keys == test_keys
assert eh.server_connector is mock_eh.server_connector
assert mock_eh.started
assert not mock_eh.last_message
eh.server_connector.simulate_message(test_msg)
assert mock_eh.last_message == test_msg
assert not mock_eh.cleaned_up
EventHandler.stop_all_instances()
assert mock_eh.cleaned_up
def test_build_from_simple_object(test_keys, test_msg):
class SimpleMockEventHandler:
# pylint: disable=no-self-use
def handle_event(self, message, server_connector):
raise RuntimeError(message, server_connector)
mock_eh = SimpleMockEventHandler()
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
def test_build_from_callable(test_keys, test_msg):
mock_eh = MockCallable()
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
assert mock_eh.started
assert mock_eh.server_connector is eh.server_connector
assert eh.server_connector.keys == test_keys
assert not mock_eh.last_message
eh.server_connector.simulate_message(test_msg)
assert mock_eh.last_message == test_msg
assert not mock_eh.cleaned_up
eh.stop()
assert mock_eh.cleaned_up
def test_build_from_function(test_keys, test_msg):
def some_function(message, server_connector):
raise RuntimeError(message, server_connector.keys)
eh = MockEventHandlerFactory().build(some_function, keys=test_keys)
assert eh.server_connector.keys == test_keys
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
def test_build_from_lambda(test_keys, test_msg):
def assert_messages_equal(msg):
assert msg == test_msg
fun = lambda msg, sc: assert_messages_equal(msg)
eh = MockEventHandlerFactory().build(fun, keys=test_keys)
eh.server_connector.simulate_message(test_msg)
def test_build_raises_if_no_key(test_keys):
eh = MockEventHandler()
with pytest.raises(ValueError):
MockEventHandlerFactory().build(eh)
def handle_event(*_):
pass
with pytest.raises(ValueError):
MockEventHandlerFactory().build(handle_event)
with pytest.raises(ValueError):
MockEventHandlerFactory().build(lambda msg, sc: None)
WithKeysEventHandler = EventHandler
WithKeysEventHandler.keys = test_keys
MockEventHandlerFactory().build(eh, event_handler_type=WithKeysEventHandler)
eh.keys = test_keys
MockEventHandlerFactory().build(eh)

View File

@ -0,0 +1,6 @@
from .inotify import InotifyObserver
from .inotify import (
InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
InotifyDirMovedEvent, InotifyDirDeletedEvent
)

View File

@ -0,0 +1,189 @@
# pylint: disable=too-few-public-methods
from typing import Iterable
from time import time
from os.path import abspath, dirname, isdir
from watchdog.observers import Observer
from watchdog.events import FileSystemMovedEvent, PatternMatchingEventHandler
from watchdog.events import (
FileCreatedEvent, FileModifiedEvent, FileMovedEvent, FileDeletedEvent,
DirCreatedEvent, DirModifiedEvent, DirMovedEvent, DirDeletedEvent
)
class InotifyEvent:
def __init__(self, src_path):
self.date = time()
self.src_path = src_path
def __str__(self):
return self.__repr__()
def __repr__(self):
return f'{self.__class__.__name__}({self.src_path})'
class InotifyMovedEvent(InotifyEvent):
def __init__(self, src_path, dest_path):
self.dest_path = dest_path
super().__init__(src_path)
def __repr__(self):
return f'{self.__class__.__name__}({self.src_path}, {self.dest_path})'
class InotifyFileCreatedEvent(InotifyEvent):
pass
class InotifyFileModifiedEvent(InotifyEvent):
pass
class InotifyFileMovedEvent(InotifyMovedEvent):
pass
class InotifyFileDeletedEvent(InotifyEvent):
pass
class InotifyDirCreatedEvent(InotifyEvent):
pass
class InotifyDirModifiedEvent(InotifyEvent):
pass
class InotifyDirMovedEvent(InotifyMovedEvent):
pass
class InotifyDirDeletedEvent(InotifyEvent):
pass
class InotifyObserver:
def __init__(self, path, patterns=None, exclude=None, recursive=False):
self._files = []
self._paths = path
self._patterns = patterns or []
self._exclude = exclude
self._recursive = recursive
self._observer = Observer()
self._reset()
def _reset(self):
if isinstance(self._paths, str):
self._paths = [self._paths]
if isinstance(self._paths, Iterable):
self._extract_files_from_paths()
else:
raise ValueError('Expected one or more string paths.')
patterns = self._files+self.patterns
handler = PatternMatchingEventHandler(patterns if patterns else None, self.exclude)
handler.on_any_event = self._dispatch_event
self._observer.unschedule_all()
for path in self.paths:
self._observer.schedule(handler, path, self._recursive)
def _extract_files_from_paths(self):
files, paths = [], []
for path in self._paths:
path = abspath(path)
if isdir(path):
paths.append(path)
else:
paths.append(dirname(path))
files.append(path)
self._files, self._paths = files, paths
@property
def paths(self):
return self._paths
@paths.setter
def paths(self, paths):
self._paths = paths
self._reset()
@property
def patterns(self):
return self._patterns
@patterns.setter
def patterns(self, patterns):
self._patterns = patterns or []
self._reset()
@property
def exclude(self):
return self._exclude
@exclude.setter
def exclude(self, exclude):
self._exclude = exclude
self._reset()
def start(self):
self._observer.start()
def stop(self):
self._observer.stop()
self._observer.join()
def _dispatch_event(self, event):
event_to_action = {
InotifyFileCreatedEvent : self.on_created,
InotifyFileModifiedEvent : self.on_modified,
InotifyFileMovedEvent : self.on_moved,
InotifyFileDeletedEvent : self.on_deleted,
InotifyDirCreatedEvent : self.on_created,
InotifyDirModifiedEvent : self.on_modified,
InotifyDirMovedEvent : self.on_moved,
InotifyDirDeletedEvent : self.on_deleted
}
event = self._transform_event(event)
self.on_any_event(event)
event_to_action[type(event)](event)
@staticmethod
def _transform_event(event):
watchdog_to_inotify = {
FileCreatedEvent : InotifyFileCreatedEvent,
FileModifiedEvent : InotifyFileModifiedEvent,
FileMovedEvent : InotifyFileMovedEvent,
FileDeletedEvent : InotifyFileDeletedEvent,
DirCreatedEvent : InotifyDirCreatedEvent,
DirModifiedEvent : InotifyDirModifiedEvent,
DirMovedEvent : InotifyDirMovedEvent,
DirDeletedEvent : InotifyDirDeletedEvent
}
try:
cls = watchdog_to_inotify[type(event)]
except KeyError:
raise NameError('Watchdog API returned an unknown event.')
if isinstance(event, FileSystemMovedEvent):
return cls(event.src_path, event.dest_path)
return cls(event.src_path)
def on_any_event(self, event):
pass
def on_created(self, event):
pass
def on_modified(self, event):
pass
def on_moved(self, event):
pass
def on_deleted(self, event):
pass

View File

@ -0,0 +1,179 @@
# pylint: disable=redefined-outer-name
from queue import Empty, Queue
from secrets import token_urlsafe
from pathlib import Path
from shutil import rmtree
from os.path import join
from os import mkdir, remove, rename
from tempfile import TemporaryDirectory
from contextlib import suppress
import watchdog
import pytest
from .inotify import InotifyObserver
from .inotify import (
InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
InotifyDirMovedEvent, InotifyDirDeletedEvent
)
with suppress(AttributeError):
watchdog.observers.inotify_buffer.InotifyBuffer.delay = 0
class InotifyContext:
def __init__(self, workdir, subdir, subfile, observer):
self.missing_events = 0
self.workdir = workdir
self.subdir = subdir
self.subfile = subfile
self.observer = observer
self.event_to_queue = {
InotifyFileCreatedEvent : self.observer.create_queue,
InotifyFileModifiedEvent : self.observer.modify_queue,
InotifyFileMovedEvent : self.observer.move_queue,
InotifyFileDeletedEvent : self.observer.delete_queue,
InotifyDirCreatedEvent : self.observer.create_queue,
InotifyDirModifiedEvent : self.observer.modify_queue,
InotifyDirMovedEvent : self.observer.move_queue,
InotifyDirDeletedEvent : self.observer.delete_queue
}
def create_random_file(self, dirname, extension):
filename = self.join(f'{dirname}/{generate_name()}{extension}')
Path(filename).touch()
return filename
def create_random_folder(self, basepath):
dirname = self.join(f'{basepath}/{generate_name()}')
mkdir(dirname)
return dirname
def join(self, path):
return join(self.workdir, path)
def check_event(self, event_type, path):
self.missing_events += 1
event = self.event_to_queue[event_type].get(timeout=0.1)
assert isinstance(event, event_type)
assert event.src_path == path
return event
def check_empty(self, event_type):
with pytest.raises(Empty):
self.event_to_queue[event_type].get(timeout=0.1)
def check_any(self):
attrs = self.observer.__dict__.values()
total = sum([q.qsize() for q in attrs if isinstance(q, Queue)])
return total+self.missing_events == len(self.observer.any_list)
class InotifyTestObserver(InotifyObserver):
def __init__(self, paths, patterns=None, exclude=None, recursive=False):
self.any_list = []
self.create_queue, self.modify_queue, self.move_queue, self.delete_queue = [Queue() for _ in range(4)]
super().__init__(paths, patterns, exclude, recursive)
def on_any_event(self, event):
self.any_list.append(event)
def on_created(self, event):
self.create_queue.put(event)
def on_modified(self, event):
self.modify_queue.put(event)
def on_moved(self, event):
self.move_queue.put(event)
def on_deleted(self, event):
self.delete_queue.put(event)
def generate_name():
return token_urlsafe(16)
@pytest.fixture()
def context():
with TemporaryDirectory() as workdir:
subdir = join(workdir, generate_name())
subfile = join(subdir, generate_name()+'.txt')
mkdir(subdir)
Path(subfile).touch()
monitor = InotifyTestObserver(workdir, recursive=True)
monitor.start()
yield InotifyContext(workdir, subdir, subfile, monitor)
def test_create(context):
newfile = context.create_random_file(context.workdir, '.txt')
context.check_event(InotifyFileCreatedEvent, newfile)
newdir = context.create_random_folder(context.workdir)
context.check_event(InotifyDirCreatedEvent, newdir)
assert context.check_any()
def test_modify(context):
with open(context.subfile, 'wb', buffering=0) as ofile:
ofile.write(b'text')
context.check_event(InotifyFileModifiedEvent, context.subfile)
while True:
try:
context.observer.modify_queue.get(timeout=0.1)
context.missing_events += 1
except Empty:
break
rename(context.subfile, context.subfile+'_new')
context.check_event(InotifyDirModifiedEvent, context.subdir)
assert context.check_any()
def test_move(context):
rename(context.subdir, context.subdir+'_new')
context.check_event(InotifyDirMovedEvent, context.subdir)
context.check_event(InotifyFileMovedEvent, context.subfile)
assert context.check_any()
def test_delete(context):
rmtree(context.subdir)
context.check_event(InotifyFileDeletedEvent, context.subfile)
context.check_event(InotifyDirDeletedEvent, context.subdir)
assert context.check_any()
def test_paths(context):
context.observer.paths = context.subdir
newdir = context.create_random_folder(context.workdir)
newfile = context.create_random_file(context.subdir, '.txt')
context.check_event(InotifyDirModifiedEvent, context.subdir)
context.check_event(InotifyFileCreatedEvent, newfile)
context.observer.paths = [newdir, newfile]
remove(newfile)
context.check_event(InotifyFileDeletedEvent, newfile)
assert context.check_any()
context.observer.paths = context.workdir
def test_patterns(context):
context.observer.patterns = ['*.txt']
context.create_random_file(context.subdir, '.bin')
newfile = context.create_random_file(context.subdir, '.txt')
context.check_event(InotifyFileCreatedEvent, newfile)
context.check_empty(InotifyFileCreatedEvent)
assert context.check_any()
context.observer.patterns = None
def test_exclude(context):
context.observer.exclude = ['*.txt']
context.create_random_file(context.subdir, '.txt')
newfile = context.create_random_file(context.subdir, '.bin')
context.check_event(InotifyFileCreatedEvent, newfile)
context.check_empty(InotifyFileCreatedEvent)
assert context.check_any()
context.observer.exclude = None
def test_stress(context):
newfile = []
for i in range(1024):
newfile.append(context.create_random_file(context.subdir, '.txt'))
for i in range(1024):
context.check_event(InotifyFileCreatedEvent, newfile[i])
assert context.check_any()

27
tfw/internals/lazy.py Normal file
View File

@ -0,0 +1,27 @@
from functools import update_wrapper, wraps
class lazy_property:
"""
Decorator that replaces a function with the value
it calculates on the first call.
"""
def __init__(self, func):
self.func = func
update_wrapper(self, func)
def __get__(self, instance, owner):
if instance is None:
return self # avoids potential __new__ TypeError
value = self.func(instance)
setattr(instance, self.func.__name__, value)
return value
def lazy_factory(fun):
class wrapper:
@wraps(fun)
@lazy_property
def instance(self): # pylint: disable=no-self-use
return fun()
return wrapper()

View File

@ -0,0 +1,4 @@
from .serialization import serialize_tfw_msg, deserialize_tfw_msg, with_deserialize_tfw_msg, message_bytes
from .server_connector import ServerUplinkConnector, ServerDownlinkConnector, ServerConnector
from .event_handler_connector import EventHandlerConnector
from .scope import Scope

View File

@ -0,0 +1,48 @@
import logging
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
LOG = logging.getLogger(__name__)
class EventHandlerDownlinkConnector:
def __init__(self, bind_addr):
self._zmq_pull_socket = zmq.Context.instance().socket(zmq.PULL)
self._zmq_pull_socket.setsockopt(zmq.RCVHWM, 0)
self._zmq_pull_stream = ZMQStream(self._zmq_pull_socket)
self._zmq_pull_socket.bind(bind_addr)
LOG.debug('Pull socket bound to %s', bind_addr)
def register_callback(self, callback):
callback = with_deserialize_tfw_msg(callback)
self._zmq_pull_stream.on_recv(callback)
def close(self):
self._zmq_pull_stream.close()
class EventHandlerUplinkConnector:
def __init__(self, bind_addr):
self._zmq_pub_socket = zmq.Context.instance().socket(zmq.PUB)
self._zmq_pub_socket.setsockopt(zmq.SNDHWM, 0)
self._zmq_pub_socket.bind(bind_addr)
LOG.debug('Pub socket bound to %s', bind_addr)
def send_message(self, message: dict):
self._zmq_pub_socket.send_multipart(serialize_tfw_msg(message))
def close(self):
self._zmq_pub_socket.close()
class EventHandlerConnector(EventHandlerDownlinkConnector, EventHandlerUplinkConnector):
def __init__(self, downlink_bind_addr, uplink_bind_addr):
EventHandlerDownlinkConnector.__init__(self, downlink_bind_addr)
EventHandlerUplinkConnector.__init__(self, uplink_bind_addr)
def close(self):
EventHandlerDownlinkConnector.close(self)
EventHandlerUplinkConnector.close(self)

View File

@ -0,0 +1,7 @@
from enum import Enum
class Scope(Enum):
ZMQ = 'zmq'
WEBSOCKET = 'websocket'
BROADCAST = 'broadcast'

View File

@ -0,0 +1,104 @@
"""
TFW JSON message format
message:
{
"key": string, # addressing
"data": {...}, # payload
"trigger": string # FSM trigger
}
ZeroMQ's sub-pub sockets use enveloped messages
(http://zguide.zeromq.org/page:all#Pub-Sub-Message-Envelopes)
and TFW also uses them internally. This means that on ZMQ sockets
we always send the messages key separately and then the actual
message (which contains the key as well) like so:
socket.send_multipart([message['key'], message])
The purpose of this module is abstracting away this low level behaviour.
"""
import json
from functools import wraps
def serialize_tfw_msg(message):
"""
Create TFW multipart data from message dict
"""
return _serialize_all(message['key'], message)
def with_deserialize_tfw_msg(fun):
@wraps(fun)
def wrapper(message_parts):
message = deserialize_tfw_msg(*message_parts)
return fun(message)
return wrapper
def deserialize_tfw_msg(*args):
"""
Return message from TFW multipart data
"""
return _deserialize_all(*args)[1]
def _serialize_all(*args):
return tuple(
_serialize_single(arg)
for arg in args
)
def _deserialize_all(*args):
return tuple(
_deserialize_single(arg)
for arg in args
)
def _serialize_single(data):
"""
Return input as bytes
(serialize input if it is JSON)
"""
if not isinstance(data, str):
data = message_bytes(data)
return _encode_if_needed(data)
def message_bytes(message):
return json.dumps(message, sort_keys=True).encode()
def _deserialize_single(data):
"""
Try parsing input as JSON, return it as
string if parsing fails.
"""
try:
return json.loads(data)
except ValueError:
return _decode_if_needed(data)
def _encode_if_needed(value):
"""
Return input as bytes
(encode if input is string)
"""
if isinstance(value, str):
value = value.encode('utf-8')
return value
def _decode_if_needed(value):
"""
Return input as string
(decode if input is bytes)
"""
if isinstance(value, (bytes, bytearray)):
value = value.decode('utf-8')
return value

View File

@ -0,0 +1,65 @@
import logging
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .scope import Scope
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
LOG = logging.getLogger(__name__)
class ServerDownlinkConnector:
def __init__(self, connect_addr):
self.keys = []
self._on_recv_callback = None
self._zmq_sub_socket = zmq.Context.instance().socket(zmq.SUB)
self._zmq_sub_socket.setsockopt(zmq.RCVHWM, 0)
self._zmq_sub_socket.connect(connect_addr)
self._zmq_sub_stream = ZMQStream(self._zmq_sub_socket)
def subscribe(self, *keys):
for key in keys:
self._zmq_sub_socket.setsockopt_string(zmq.SUBSCRIBE, key)
self.keys.append(key)
def unsubscribe(self, *keys):
for key in keys:
self._zmq_sub_socket.setsockopt_string(zmq.UNSUBSCRIBE, key)
self.keys.remove(key)
def register_callback(self, callback):
self._on_recv_callback = callback
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
def _on_recv(self, message):
key = message['key']
if key in self.keys or '' in self.keys:
self._on_recv_callback(message)
def close(self):
self._zmq_sub_stream.close()
class ServerUplinkConnector:
def __init__(self, connect_addr):
self._zmq_push_socket = zmq.Context.instance().socket(zmq.PUSH)
self._zmq_push_socket.setsockopt(zmq.SNDHWM, 0)
self._zmq_push_socket.connect(connect_addr)
def send_message(self, message, scope=Scope.ZMQ):
message['scope'] = scope.value
self._zmq_push_socket.send_multipart(serialize_tfw_msg(message))
def close(self):
self._zmq_push_socket.close()
class ServerConnector(ServerDownlinkConnector, ServerUplinkConnector):
def __init__(self, downlink_connect_addr, uplink_connect_addr):
ServerDownlinkConnector.__init__(self, downlink_connect_addr)
ServerUplinkConnector.__init__(self, uplink_connect_addr)
def close(self):
ServerDownlinkConnector.close(self)
ServerUplinkConnector.close(self)

View File

@ -0,0 +1 @@
from .zmq_websocket_router import ZMQWebSocketRouter

View File

@ -0,0 +1,69 @@
import json
import logging
from tornado.websocket import WebSocketHandler
from tfw.internals.networking import Scope
LOG = logging.getLogger(__name__)
class ZMQWebSocketRouter(WebSocketHandler):
# pylint: disable=abstract-method,attribute-defined-outside-init
instances = set()
def initialize(self, **kwargs):
self.event_handler_connector = kwargs['event_handler_connector']
self.tfw_router = TFWRouter(self.send_to_zmq, self.send_to_websockets)
def send_to_zmq(self, message):
self.event_handler_connector.send_message(message)
@classmethod
def send_to_websockets(cls, message):
for instance in cls.instances:
instance.write_message(message)
def prepare(self):
type(self).instances.add(self)
def on_close(self):
type(self).instances.remove(self)
def open(self, *args, **kwargs):
LOG.debug('WebSocket connection initiated!')
self.event_handler_connector.register_callback(self.zmq_callback)
def zmq_callback(self, message):
LOG.debug('Received on ZMQ pull socket: %s', message)
self.tfw_router.route(message)
def on_message(self, message):
message = json.loads(message)
LOG.debug('Received on WebSocket: %s', message)
self.tfw_router.route(message)
# much secure, very cors, wow
def check_origin(self, origin):
return True
class TFWRouter:
def __init__(self, send_to_zmq, send_to_websockets):
self.send_to_zmq = send_to_zmq
self.send_to_websockets = send_to_websockets
def route(self, message):
scope = Scope(message.pop('scope', 'zmq'))
routing_table = {
Scope.ZMQ: self.send_to_zmq,
Scope.WEBSOCKET: self.send_to_websockets,
Scope.BROADCAST: self.broadcast
}
action = routing_table[scope]
action(message)
def broadcast(self, message):
self.send_to_zmq(message)
self.send_to_websockets(message)