Merge branch 'state_rewind'

This commit is contained in:
Kristóf Tóth 2018-07-20 15:19:53 +02:00
commit 878b6d3f2a
25 changed files with 473 additions and 136 deletions

View File

@ -34,10 +34,11 @@ ENV PYTHONPATH="/usr/local/lib" \
TFW_NGINX_CONF="/etc/nginx/nginx.conf" \
TFW_NGINX_DEFAULT="/etc/nginx/sites-enabled/default" \
TFW_NGINX_COMPONENTS="/etc/nginx/components" \
TFW_LIB_DIR="/usr/local/lib/" \
TFW_LIB_DIR="/usr/local/lib" \
TFW_TERMINADO_DIR="/tmp/terminado_server" \
TFW_FRONTEND_DIR="/srv/frontend" \
TFW_SERVER_DIR="/srv/.tfw" \
TFW_AUTH_KEY="/tmp/tfw-auth.key" \
TFW_HISTFILE="/home/${AVATAO_USER}/.bash_history" \
PROMPT_COMMAND="history -a"
@ -50,7 +51,7 @@ COPY supervisor/components/ ${TFW_SUPERVISORD_COMPONENTS}
COPY nginx/nginx.conf ${TFW_NGINX_CONF}
COPY nginx/default.conf ${TFW_NGINX_DEFAULT}
COPY nginx/components/ ${TFW_NGINX_COMPONENTS}
COPY lib LICENSE ${TFW_LIB_DIR}
COPY lib LICENSE ${TFW_LIB_DIR}/
COPY supervisor/tfw_server.py ${TFW_SERVER_DIR}/
RUN for dir in "${TFW_LIB_DIR}"/{tfw,tao,envvars} "/etc/nginx" "/etc/supervisor"; do \
@ -66,7 +67,7 @@ ONBUILD COPY ${BUILD_CONTEXT}/supervisor/ ${TFW_SUPERVISORD_COMPONENTS}
ONBUILD RUN for f in "${TFW_NGINX_DEFAULT}" ${TFW_NGINX_COMPONENTS}/*.conf; do \
envsubst "$(printenv | cut -d= -f1 | grep TFW_ | sed -e 's/^/$/g')" < $f > $f~ && mv $f~ $f ;\
done
ONBUILD VOLUME ["/etc/nginx", "/var/lib/nginx", "/var/log/nginx"]
ONBUILD VOLUME ["/etc/nginx", "/var/lib/nginx", "/var/log/nginx", "${TFW_LIB_DIR}/envvars", "${TFW_LIB_DIR}/tfw"]
ONBUILD COPY ${BUILD_CONTEXT}/frontend /data/
ONBUILD RUN test -z "${NOFRONTEND}" && cd /data && yarn install --frozen-lockfile || :

View File

@ -1,7 +1,7 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from .event_handler_base import EventHandlerBase, TriggeredEventHandler, BroadcastingEventHandler
from .event_handler_base import EventHandlerBase, FSMAwareEventHandler, BroadcastingEventHandler
from .fsm_base import FSMBase
from .linear_fsm import LinearFSM
from .yaml_fsm import YamlFSM

View File

@ -9,3 +9,4 @@ from .history_monitor import HistoryMonitor, BashMonitor, GDBMonitor
from .terminal_commands import TerminalCommands
from .log_monitoring_event_handler import LogMonitoringEventHandler
from .fsm_managing_event_handler import FSMManagingEventHandler
from .snapshot_provider import SnapshotProvider

View File

@ -2,6 +2,7 @@
# All Rights Reserved. See LICENSE file for details.
from tfw import EventHandlerBase
from tfw.crypto import KeyManager, sign_message
from tfw.config.logs import logging
LOG = logging.getLogger(__name__)
@ -12,6 +13,7 @@ class FSMManagingEventHandler(EventHandlerBase):
super().__init__(key)
self.fsm = fsm_type()
self._fsm_updater = FSMUpdater(self.fsm)
self.auth_key = KeyManager().auth_key
self.command_handlers = {
'trigger': self.handle_trigger,
@ -22,16 +24,21 @@ class FSMManagingEventHandler(EventHandlerBase):
try:
data = message['data']
message['data'] = self.command_handlers[data['command']](data)
self.server_connector.broadcast(self._fsm_updater.generate_fsm_update())
fsm_update_message = self._fsm_updater.generate_fsm_update()
sign_message(self.auth_key, message)
sign_message(self.auth_key, fsm_update_message)
self.server_connector.broadcast(fsm_update_message)
return message
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def handle_trigger(self, data):
self.fsm.step(data['value'])
trigger = data['value']
self.fsm.step(trigger)
return data
def handle_update(self, data):
# pylint: disable=no-self-use
return data
@ -51,7 +58,11 @@ class FSMUpdater:
{'trigger': trigger}
for trigger in self.fsm.get_triggers(self.fsm.state)
]
last_trigger = self.fsm.trigger_history[-1] if self.fsm.trigger_history else None
in_accepted_state = state in self.fsm.accepted_states
return {
'current_state': state,
'valid_transitions': valid_transitions
'valid_transitions': valid_transitions,
'last_trigger': last_trigger,
'in_accepted_state': in_accepted_state
}

View File

@ -4,7 +4,7 @@
from os.path import isfile, join, relpath, exists, isdir, realpath
from glob import glob
from fnmatch import fnmatchcase
from collections import Iterable
from typing import Iterable
from tfw import EventHandlerBase
from tfw.mixins import MonitorManagerMixin
@ -103,7 +103,7 @@ class FileManager: # pylint: disable=too-many-instance-attributes
class IdeEventHandler(EventHandlerBase, MonitorManagerMixin):
# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments,anomalous-backslash-in-string
"""
Event handler implementing the backend of our browser based IDE.
By default all files in the directory specified in __init__ are displayed

View File

@ -38,7 +38,6 @@ class ProcessManagingEventHandler(EventHandlerBase):
"""
def __init__(self, key, dirmonitor=None, log_tail=0):
super().__init__(key)
self.key = key
self.monitor = dirmonitor
self.processmanager = ProcessManager()
self.log_tail = log_tail

View File

@ -0,0 +1,175 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
import re
from subprocess import run, CalledProcessError
from getpass import getuser
from os.path import isdir
from datetime import datetime
from uuid import uuid4
class SnapshotProvider:
def __init__(self, directory, git_dir):
self._classname = self.__class__.__name__
author = f'{getuser()} via TFW {self._classname}'
self.gitenv = {
'GIT_DIR': git_dir,
'GIT_WORK_TREE': directory,
'GIT_AUTHOR_NAME': author,
'GIT_AUTHOR_EMAIL': '',
'GIT_COMMITTER_NAME': author,
'GIT_COMMITTER_EMAIL': '',
'GIT_PAGER': 'cat'
}
self._init_repo()
self.__last_valid_branch = self._branch
def _init_repo(self):
self._check_environment()
if not self._repo_is_initialized:
self._run(('git', 'init'))
if self._number_of_commits == 0:
try:
self._snapshot()
except CalledProcessError:
raise EnvironmentError(f'{self._classname} cannot init on empty directories!')
self._check_head_not_detached()
def _check_environment(self):
if not isdir(self.gitenv['GIT_DIR']) or not isdir(self.gitenv['GIT_WORK_TREE']):
raise EnvironmentError(f'{self._classname}: "directory" and "git_dir" must exist!')
@property
def _repo_is_initialized(self):
return self._run(
('git', 'status'),
check=False
).returncode == 0
@property
def _number_of_commits(self):
return int(
self._get_stdout((
'git', 'rev-list',
'--all',
'--count'
))
)
def _snapshot(self):
self._run((
'git', 'add',
'-A'
))
self._run((
'git', 'commit',
'-m', 'Snapshot'
))
def _check_head_not_detached(self):
if self._head_detached:
raise EnvironmentError(f'{self._classname} cannot init from detached HEAD state!')
@property
def _head_detached(self):
return self._branch == 'HEAD'
@property
def _branch(self):
return self._get_stdout((
'git', 'rev-parse',
'--abbrev-ref', 'HEAD'
))
def _get_stdout(self, *args, **kwargs):
kwargs['capture_output'] = True
stdout_bytes = self._run(*args, **kwargs).stdout
return stdout_bytes.decode().rstrip('\n')
def _run(self, *args, **kwargs):
if 'check' not in kwargs:
kwargs['check'] = True
if 'env' not in kwargs:
kwargs['env'] = self.gitenv
return run(*args, **kwargs)
def take_snapshot(self):
if self._head_detached:
self._checkout_new_branch_from_head()
self._snapshot()
def _checkout_new_branch_from_head(self):
branch_name = uuid4()
self._run((
'git', 'branch',
branch_name
))
self._checkout(branch_name)
def _checkout(self, what):
self._run((
'git', 'checkout',
what
))
def restore_snapshot(self, date):
commit = self._get_commit_from_timestamp(date)
self._checkout(commit)
def _get_commit_from_timestamp(self, date):
return self._get_stdout((
'git', 'rev-list',
'--date=iso',
'-n', '1',
f'--before="{date.isoformat()}"',
self._last_valid_branch
))
@property
def _last_valid_branch(self):
if not self._head_detached:
self.__last_valid_branch = self._branch
return self.__last_valid_branch
@property
def all_timelines(self):
return self._branches
@property
def _branches(self):
git_branch_output = self._get_stdout(('git', 'branch'))
regex_pattern = re.compile(r'(?:[^\S\n]|[*])') # matches '*' and non-newline whitespace chars
return re.sub(regex_pattern, '', git_branch_output).splitlines()
@property
def timeline(self):
return self._last_valid_branch
@timeline.setter
def timeline(self, value):
self._checkout(value)
@property
def snapshots(self):
return self._pretty_log_branch()
def _pretty_log_branch(self):
git_log_output = self._get_stdout((
'git', 'log',
'--pretty=%H@%aI'
))
commits = []
for line in git_log_output.splitlines():
commit_hash, timestamp = line.split('@')
commits.append({
'hash': commit_hash,
'timestamp': datetime.fromisoformat(timestamp)
})
return commits

View File

@ -11,6 +11,7 @@ LOG = logging.getLogger(__name__)
class TerminalCommands(ABC):
# pylint: disable=anomalous-backslash-in-string
"""
A class you can use to define hooks for terminal commands. This means that you can
have python code executed when the user enters a specific command to the terminal on

107
lib/tfw/crypto.py Normal file
View File

@ -0,0 +1,107 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from functools import wraps
from base64 import b64encode, b64decode
from copy import deepcopy
from hashlib import md5
from os import urandom, chmod
from os.path import exists
from stat import S_IRUSR, S_IWUSR, S_IXUSR
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
from cryptography.exceptions import InvalidSignature
from tfw.networking import message_bytes
from tfw.decorators import lazy_property
from tfw.config import TFWENV
def message_checksum(message):
return md5(message_bytes(message)).hexdigest()
def sign_message(key, message):
signature = message_signature(key, message)
message['signature'] = b64encode(signature).decode()
def message_signature(key, message):
return HMAC(key, message_bytes(message)).signature
def verify_message(key, message):
message = deepcopy(message)
try:
signature_b64 = message.pop('signature')
signature = b64decode(signature_b64)
actual_signature = message_signature(key, message)
return signature == actual_signature
except KeyError:
return False
class KeyManager:
def __init__(self):
self.keyfile = TFWENV.AUTH_KEY
if not exists(self.keyfile):
self._init_auth_key()
@lazy_property
def auth_key(self):
with open(self.keyfile, 'rb') as ifile:
return ifile.read()
def _init_auth_key(self):
key = self.generate_key()
with open(self.keyfile, 'wb') as ofile:
ofile.write(key)
self._chmod_700_keyfile()
return key
@staticmethod
def generate_key():
return urandom(32)
def _chmod_700_keyfile(self):
chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
class HMAC:
def __init__(self, key, message):
self.key = key
self.message = message
self._hmac = _HMAC(
key=key,
algorithm=SHA256(),
backend=default_backend()
)
def _reload_if_finalized(f):
# pylint: disable=no-self-argument,not-callable
@wraps(f)
def wrapped(instance, *args, **kwargs):
if getattr(instance, '_finalized', False):
instance.__init__(instance.key, instance.message)
ret_val = f(instance, *args, **kwargs)
setattr(instance, '_finalized', True)
return ret_val
return wrapped
@property
@_reload_if_finalized
def signature(self):
self._hmac.update(self.message)
signature = self._hmac.finalize()
return signature
@_reload_if_finalized
def verify(self, signature):
self._hmac.update(self.message)
try:
self._hmac.verify(signature)
return True
except InvalidSignature:
return False

View File

@ -2,10 +2,9 @@
# All Rights Reserved. See LICENSE file for details.
from abc import ABC, abstractmethod
from json import dumps
from hashlib import md5
from tfw.networking.event_handlers import ServerConnector
from tfw.crypto import message_checksum, KeyManager, verify_message
from tfw.config.logs import logging
LOG = logging.getLogger(__name__)
@ -20,10 +19,17 @@ class EventHandlerBase(ABC):
"""
def __init__(self, key):
self.server_connector = ServerConnector()
self.key = key
self.subscribe(self.key, 'reset')
self.keys = [key]
self.subscribe(*self.keys)
self.server_connector.register_callback(self.event_handler_callback)
@property
def key(self):
"""
Returns the oldest key this EventHandler was subscribed to.
"""
return self.keys[0]
def event_handler_callback(self, message):
"""
Callback that is invoked when receiving a message.
@ -48,7 +54,7 @@ class EventHandlerBase(ABC):
subscribed to 'fsm' will receive 'fsm_update'
messages as well.
"""
return self.key == message['key']
return message['key'] in self.keys
def dispatch_handling(self, message):
"""
@ -57,9 +63,7 @@ class EventHandlerBase(ABC):
:param message: the message received
:returns: the message to send back
"""
if message['key'] != 'reset':
return self.handle_event(message)
return self.handle_reset(message)
return self.handle_event(message)
@abstractmethod
def handle_event(self, message):
@ -71,16 +75,6 @@ class EventHandlerBase(ABC):
"""
raise NotImplementedError
def handle_reset(self, message):
# pylint: disable=unused-argument,no-self-use
"""
Usually 'reset' events receive some sort of special treatment.
:param message: the message received
:returns: the message to send back
"""
return None
def subscribe(self, *keys):
"""
Subscribe this EventHandler to receive events for given keys.
@ -92,6 +86,7 @@ class EventHandlerBase(ABC):
"""
for key in keys:
self.server_connector.subscribe(key)
self.keys.append(key)
def unsubscribe(self, *keys):
"""
@ -101,6 +96,7 @@ class EventHandlerBase(ABC):
"""
for key in keys:
self.server_connector.unsubscribe(key)
self.keys.remove(key)
def cleanup(self):
"""
@ -110,20 +106,42 @@ class EventHandlerBase(ABC):
pass
class TriggeredEventHandler(EventHandlerBase, ABC):
class FSMAwareEventHandler(EventHandlerBase, ABC):
# pylint: disable=abstract-method
"""
Abstract base class for EventHandlers which are only triggered in case
TFWServer has successfully triggered an FSM step defined in __init__.
Abstract base class for EventHandlers which automatically
keep track of the state of the TFW FSM.
"""
def __init__(self, key, trigger):
def __init__(self, key):
super().__init__(key)
self.trigger = trigger
self.subscribe('fsm_update')
self.fsm_state = None
self.in_accepted_state = False
self._auth_key = KeyManager().auth_key
def dispatch_handling(self, message):
if message.get('trigger') == self.trigger:
return super().dispatch_handling(message)
return None
if message['key'] == 'fsm_update':
if verify_message(self._auth_key, message):
self._handle_fsm_update(message)
return None
return super().dispatch_handling(message)
def _handle_fsm_update(self, message):
try:
new_state = message['data']['current_state']
trigger = message['data']['last_trigger']
if self.fsm_state != new_state:
self.handle_fsm_step(self.fsm_state, new_state, trigger)
self.fsm_state = new_state
self.in_accepted_state = message['data']['in_accepted_state']
except KeyError:
LOG.error('Invalid fsm_update message received!')
def handle_fsm_step(self, from_state, to_state, trigger):
"""
Called in case the TFW FSM has stepped.
"""
pass
class BroadcastingEventHandler(EventHandlerBase, ABC):
@ -137,7 +155,7 @@ class BroadcastingEventHandler(EventHandlerBase, ABC):
self.own_message_hashes = []
def event_handler_callback(self, message):
message_hash = self.hash_message(message)
message_hash = message_checksum(message)
if message_hash in self.own_message_hashes:
self.own_message_hashes.remove(message_hash)
@ -145,10 +163,5 @@ class BroadcastingEventHandler(EventHandlerBase, ABC):
response = self.dispatch_handling(message)
if response:
self.own_message_hashes.append(self.hash_message(response))
self.own_message_hashes.append(message_checksum(response))
self.server_connector.broadcast(response)
@staticmethod
def hash_message(message):
message_bytes = dumps(message, sort_keys=True).encode()
return md5(message_bytes).hexdigest()

View File

@ -22,8 +22,9 @@ class FSMBase(Machine, CallbackMixin):
states, transitions = [], []
def __init__(self, initial=None, accepted_states=None):
self.accepted_states = accepted_states or [self.states[-1]]
self.accepted_states = accepted_states or [self.states[-1].name]
self.trigger_predicates = defaultdict(list)
self.trigger_history = []
Machine.__init__(
self,
@ -57,9 +58,9 @@ class FSMBase(Machine, CallbackMixin):
for predicate in self.trigger_predicates[trigger]
)
# TODO: think about what could we do when this prevents triggering
if all(predicate_results):
try:
self.trigger(trigger)
self.trigger_history.append(trigger)
except (AttributeError, MachineError):
LOG.debug('FSM failed to execute nonexistent trigger: "%s"', trigger)

View File

@ -1,6 +1,8 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from transitions import State
from .fsm_base import FSMBase
@ -13,17 +15,17 @@ class LinearFSM(FSMBase):
(0) -- step_1 --> (1) -- step_2 --> (2) -- step_3 --> (3) ... and so on
"""
def __init__(self, number_of_steps):
self.states = list(map(str, range(number_of_steps)))
self.states = [State(name=str(index)) for index in range(number_of_steps)]
self.transitions = []
for index in self.states[:-1]:
for state in self.states[:-1]:
self.transitions.append({
'trigger': f'step_{int(index)+1}',
'source': index,
'dest': str(int(index)+1)
'trigger': f'step_{int(state.name)+1}',
'source': state.name,
'dest': str(int(state.name)+1)
})
self.transitions.append({
'trigger': 'step_next',
'source': index,
'dest': str(int(index)+1)
'source': state.name,
'dest': str(int(state.name)+1)
})
super(LinearFSM, self).__init__()

View File

@ -9,6 +9,7 @@ from tfw.decorators import lazy_property
class CallbackMixin:
@lazy_property
def _callbacks(self):
# pylint: disable=no-self-use
return []
def subscribe_callback(self, callback, *args, **kwargs):

View File

@ -9,6 +9,7 @@ from tfw.decorators import lazy_property
class ObserverMixin:
@lazy_property
def observer(self):
# pylint: disable=no-self-use
return Observer()
def watch(self):

View File

@ -13,6 +13,7 @@ from tfw.config import TFWENV
class SupervisorBaseMixin:
@lazy_property
def supervisor(self):
# pylint: disable=no-self-use
return xmlrpc.client.ServerProxy(TFWENV.SUPERVISOR_HTTP_URI).supervisor

View File

@ -1,9 +1,9 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from .serialization import serialize_tfw_msg, deserialize_tfw_msg, with_deserialize_tfw_msg
from .serialization import serialize_tfw_msg, deserialize_tfw_msg
from .serialization import with_deserialize_tfw_msg, message_bytes
from .zmq_connector_base import ZMQConnectorBase
# from .controller_connector import ControllerConnector # TODO: readd once controller stuff is resolved
from .message_sender import MessageSender
from .event_handlers.server_connector import ServerUplinkConnector as TFWServerConnector
from .server.tfw_server import TFWServer

View File

@ -1,18 +0,0 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from tfw.config import TFWENV
from tfw.networking import ZMQConnectorBase
class ControllerConnector(ZMQConnectorBase):
def __init__(self, zmq_context=None):
super(ControllerConnector, self).__init__(zmq_context)
self._zmq_rep_socket = self._zmq_context.socket(zmq.REP)
self._zmq_rep_socket.connect(f'tcp://localhost:{TFWENV.CONTROLLER_PORT}')
self._zmq_rep_stream = ZMQStream(self._zmq_rep_socket)
self.register_callback = self._zmq_rep_stream.on_recv_stream

View File

@ -44,3 +44,11 @@ class MessageSender:
'key': self.queue_key,
'data': data
})
@staticmethod
def generate_messages_from_queue(queue_message):
for message in queue_message['data']['messages']:
yield {
'key': 'message',
'data': message
}

View File

@ -67,10 +67,14 @@ def _serialize_single(data):
(serialize input if it is JSON)
"""
if not isinstance(data, str):
data = json.dumps(data)
data = message_bytes(data)
return _encode_if_needed(data)
def message_bytes(message):
return json.dumps(message, sort_keys=True).encode()
def _deserialize_single(data):
"""
Try parsing input as JSON, return it as

View File

@ -3,4 +3,3 @@
from .event_handler_connector import EventHandlerConnector, EventHandlerUplinkConnector, EventHandlerDownlinkConnector
from .tfw_server import TFWServer
# from .controller_responder import ControllerResponder # TODO: readd once controller stuff is resolved

View File

@ -1,38 +0,0 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from tfw.networking import deserialize_all, serialize_all, ControllerConnector
class ControllerResponder:
def __init__(self, fsm):
self.fsm = fsm
self.token = None
self.controller_connector = ControllerConnector()
self.controller_connector.register_callback(self.handle_controller_request)
self.controller_request_handlers = {
'solution_check': self.handle_solution_check_request,
'test': self.handle_test_request,
'token': self.handle_token_request
}
def handle_controller_request(self, stream, msg_parts):
key, data = deserialize_all(*msg_parts)
response = self.controller_request_handlers[key](data)
stream.send_multipart(serialize_all(self.token, response))
def handle_test_request(self, data):
# pylint: disable=unused-argument,no-self-use
return 'OK'
def handle_token_request(self, data):
if self.token is None:
self.token = data
return {'token': self.token}
def handle_solution_check_request(self, data):
# pylint: disable=unused-argument
return {
'solved': self.fsm.is_solved(),
'message': 'solved' if self.fsm.is_solved() else 'not solved'
}

View File

@ -1,10 +1,13 @@
# Copyright (C) 2018 Avatao.com Innovative Learning Kft.
# All Rights Reserved. See LICENSE file for details.
from abc import ABC, abstractmethod
from tornado.web import Application
from tfw.networking.event_handlers import ServerUplinkConnector
from tfw.networking.server import EventHandlerConnector
from tfw.networking import MessageSender
from tfw.config.logs import logging
from .zmq_websocket_proxy import ZMQWebSocketProxy
@ -22,11 +25,13 @@ class TFWServer:
self._uplink_connector = ServerUplinkConnector()
self.application = Application([(
r'/ws', ZMQWebSocketProxy,{
r'/ws', ZMQWebSocketProxy, {
'event_handler_connector': self._event_handler_connector,
'message_handlers': [self.handle_trigger]
})]
)
'message_handlers': [self.handle_trigger, self.handle_recover],
'frontend_message_handlers': [self.save_frontend_messages]
})])
self._frontend_messages = FrontendMessageStorage()
def handle_trigger(self, message):
if 'trigger' in message:
@ -39,5 +44,52 @@ class TFWServer:
}
})
def handle_recover(self, message):
if message['key'] == 'recover':
self._frontend_messages.replay_messages(self._uplink_connector)
self._frontend_messages.clear()
def save_frontend_messages(self, message):
self._frontend_messages.save_message(message)
def listen(self, port):
self.application.listen(port)
class MessageStorage(ABC):
def __init__(self):
self.saved_messages = []
def save_message(self, message):
if self.filter_message(message):
self.saved_messages.extend(self.transform_message(message))
@abstractmethod
def filter_message(self, message):
raise NotImplementedError
def transform_message(self, message): # pylint: disable=no-self-use
yield message
def clear(self):
self.saved_messages.clear()
class FrontendMessageStorage(MessageStorage):
def filter_message(self, message):
key = message['key']
command = message.get('data', {}).get('command')
return (
key in ('message', 'dashboard', 'queueMessages')
or key == 'ide' and command in ('select', 'read')
)
def transform_message(self, message):
if message['key'] == 'queueMessages':
yield from MessageSender.generate_messages_from_queue(message)
else:
yield message
def replay_messages(self, connector):
for message in self.saved_messages:
connector.send(message)

View File

@ -12,11 +12,15 @@ LOG = logging.getLogger(__name__)
class ZMQWebSocketProxy(WebSocketHandler):
# pylint: disable=abstract-method
instances = set()
def initialize(self, **kwargs): # pylint: disable=arguments-differ
self._event_handler_connector = kwargs['event_handler_connector']
self._message_handlers = kwargs.get('message_handlers', [])
self._frontend_message_handlers = kwargs.get('frontend_message_handlers', [])
self._eventhandler_message_handlers = kwargs.get('eventhandler_message_handlers', [])
self._proxy_filters = kwargs.get('proxy_filters', [])
self.proxy_eventhandler_to_websocket = TFWProxy(
@ -28,10 +32,18 @@ class ZMQWebSocketProxy(WebSocketHandler):
self.send_eventhandler_message
)
proxies = (self.proxy_eventhandler_to_websocket, self.proxy_websocket_to_eventhandler)
for proxy in proxies:
proxy.proxy_filters.subscribe_callbacks(*self._proxy_filters)
proxy.proxy_callbacks.subscribe_callbacks(*self._message_handlers)
self.subscribe_proxy_callbacks()
def subscribe_proxy_callbacks(self):
self.proxy_websocket_to_eventhandler.subscribe_proxy_callbacks_and_filters(
self._eventhandler_message_handlers + self._message_handlers,
self._proxy_filters
)
self.proxy_eventhandler_to_websocket.subscribe_proxy_callbacks_and_filters(
self._frontend_message_handlers + self._message_handlers,
self._proxy_filters
)
def prepare(self):
ZMQWebSocketProxy.instances.add(self)
@ -72,6 +84,7 @@ class ZMQWebSocketProxy(WebSocketHandler):
class TFWProxy:
# pylint: disable=protected-access
def __init__(self, to_source, to_destination):
self.to_source = to_source
self.to_destination = to_destination
@ -119,3 +132,7 @@ class TFWProxy:
LOG.debug('Broadcasting message: %s', message)
self.to_source(message)
self.to_destination(message)
def subscribe_proxy_callbacks_and_filters(self, proxy_callbacks, proxy_filters):
self.proxy_callbacks.subscribe_callbacks(*proxy_callbacks)
self.proxy_filters.subscribe_callbacks(*proxy_filters)

View File

@ -40,19 +40,19 @@ class YamlFSM(FSMBase):
def subscribe_and_remove_predicates(self, json_obj):
if 'predicates' in json_obj:
for predicate in json_obj['predicates']:
self.subscribe_predicate(
json_obj['trigger'],
partial(
command_statuscode_is_zero,
predicate
self.subscribe_predicate(
json_obj['trigger'],
partial(
command_statuscode_is_zero,
predicate
)
)
)
with suppress(KeyError):
json_obj.pop('predicates')
def run_command_async(command, event):
def run_command_async(command, _):
Popen(command, shell=True)
@ -62,7 +62,7 @@ def command_statuscode_is_zero(command):
class ConfigParser:
def __init__(self, config_file, jinja2_variables):
self.read_variables = singledispatch(self.read_variables)
self.read_variables = singledispatch(self._read_variables)
self.read_variables.register(dict, self._read_variables_dict)
self.read_variables.register(str, self._read_variables_str)
@ -82,16 +82,14 @@ class ConfigParser:
return ifile.read()
@staticmethod
def read_variables(variables):
def _read_variables(variables):
raise TypeError(f'Invalid variables type {type(variables)}')
@staticmethod
def _read_variables_str(variables):
if isinstance(variables, str):
with open(variables, 'r') as ifile:
return yaml.safe_load(ifile)
with open(variables, 'r') as ifile:
return yaml.safe_load(ifile)
@staticmethod
def _read_variables_dict(variables):
return variables
return variables

View File

@ -5,3 +5,4 @@ terminado==0.8.1
watchdog==0.8.3
PyYAML==3.12
Jinja2==2.10
cryptography==2.2.2