Mirror of https://github.com/avatao-content/baseimage-tutorial-framework (synced 2025-11-04 01:12:55 +00:00)

Commit: Simplify package structure

tfw/__init__.py (new file, 0 lines)

tfw/components/__init__.py (new file, 0 lines)

tfw/components/frontend/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .frontend_handler import FrontendHandler
from .message_sender import MessageSender

tfw/components/frontend/frontend_handler.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from tfw.internals.networking import Scope

from .message_storage import FrontendMessageStorage


class FrontendHandler:
    keys = ['message', 'queueMessages', 'dashboard', 'console']

    def __init__(self):
        self.server_connector = None
        self.keys = [*type(self).keys, 'recover']
        self._frontend_message_storage = FrontendMessageStorage(type(self).keys)

    def send_message(self, message):
        self.server_connector.send_message(message, scope=Scope.WEBSOCKET)

    def handle_event(self, message, _):
        self._frontend_message_storage.save_message(message)
        if message['key'] == 'recover':
            self.recover_frontend()
        self.send_message(message)

    def recover_frontend(self):
        for message in self._frontend_message_storage.messages:
            self.send_message(message)

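Illustrative usage sketch (not part of the commit): how FrontendHandler stores whitelisted messages and replays them on a 'recover' event. The StubConnector below is a hypothetical stand-in for the real server connector.

from tfw.components.frontend import FrontendHandler

class StubConnector:  # hypothetical stand-in for the real server connector
    def __init__(self):
        self.sent = []

    def send_message(self, message, scope=None):
        self.sent.append(message)

handler = FrontendHandler()
handler.server_connector = StubConnector()
handler.handle_event({'key': 'message', 'data': {'originator': 'TFW', 'message': 'hi'}}, None)
handler.handle_event({'key': 'recover', 'data': {}}, None)
# The first message is stored (its key is whitelisted), so the 'recover' event
# replays it before the recover message itself is forwarded.
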
							
								
								
									
tfw/components/frontend/message_sender.py (new file, 48 lines)
@@ -0,0 +1,48 @@
class MessageSender:
    """
    Provides mechanisms to send messages to our frontend messaging component.
    """
    def __init__(self, uplink):
        self.uplink = uplink
        self.key = 'message'
        self.queue_key = 'queueMessages'

    def send(self, originator, message):
        """
        Sends a message.
        :param originator: name of sender to be displayed on the frontend
        :param message: message to send
        """
        message = {
            'key': self.key,
            'data': {
                'originator': originator,
                'message': message
            }
        }
        self.uplink.send_message(message)

    def queue_messages(self, originator, messages):
        """
        Queues a list of messages to be displayed in a chatbot-like manner.
        :param originator: name of sender to be displayed on the frontend
        :param messages: list of messages to queue
        """
        message = {
            'key': self.queue_key,
            'data': {
                'messages': [
                    {'message': message, 'originator': originator}
                    for message in messages
                ]
            }
        }
        self.uplink.send_message(message)

    @staticmethod
    def generate_messages_from_queue(queue_message):
        for message in queue_message['data']['messages']:
            yield {
                'key': 'message',
                'data': message
            }

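A minimal sketch of the wire format MessageSender produces. The RecordingUplink is a hypothetical stand-in that just records what send_message receives.

from tfw.components.frontend import MessageSender

class RecordingUplink:  # hypothetical stand-in for the real uplink connector
    def __init__(self):
        self.messages = []

    def send_message(self, message):
        self.messages.append(message)

uplink = RecordingUplink()
sender = MessageSender(uplink)
sender.send('TFW', 'Welcome!')
# -> {'key': 'message', 'data': {'originator': 'TFW', 'message': 'Welcome!'}}
sender.queue_messages('TFW', ['Step 1', 'Step 2'])
# -> {'key': 'queueMessages', 'data': {'messages': [
#        {'message': 'Step 1', 'originator': 'TFW'},
#        {'message': 'Step 2', 'originator': 'TFW'}]}}
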
							
								
								
									
tfw/components/frontend/message_storage.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from contextlib import suppress

from .message_sender import MessageSender


class MessageStorage(ABC):
    def __init__(self):
        self._messages = []

    def save_message(self, message):
        with suppress(KeyError, AttributeError):
            if self._filter_message(message):
                self._messages.extend(self._transform_message(message))

    @abstractmethod
    def _filter_message(self, message):
        raise NotImplementedError

    def _transform_message(self, message):   # pylint: disable=no-self-use
        yield message

    def clear(self):
        self._messages.clear()

    @property
    def messages(self):
        yield from self._messages


class FrontendMessageStorage(MessageStorage):
    def __init__(self, keys):
        self._keys = keys
        super().__init__()

    def _filter_message(self, message):
        key = message['key']
        return key in self._keys

    def _transform_message(self, message):
        if message['key'] == 'queueMessages':
            yield from MessageSender.generate_messages_from_queue(message)
        else:
            yield message

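A small sketch of the storage behaviour, assuming a direct module import (the package __init__ only re-exports FrontendHandler and MessageSender): non-whitelisted keys are dropped, and queued messages are flattened into individual 'message' entries on replay.

from tfw.components.frontend.message_storage import FrontendMessageStorage

storage = FrontendMessageStorage(['message', 'queueMessages'])
storage.save_message({'key': 'console', 'data': {}})   # key not whitelisted here: dropped
storage.save_message({
    'key': 'queueMessages',
    'data': {'messages': [{'message': 'a', 'originator': 'TFW'}]}
})
print(list(storage.messages))
# [{'key': 'message', 'data': {'message': 'a', 'originator': 'TFW'}}]
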
							
								
								
									
tfw/components/fsm/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .fsm_handler import FSMHandler

tfw/components/fsm/fsm_handler.py (new file, 71 lines)
@@ -0,0 +1,71 @@
import logging

from tfw.internals.crypto import KeyManager, sign_message, verify_message
from tfw.internals.networking import Scope

from .fsm_updater import FSMUpdater


LOG = logging.getLogger(__name__)


class FSMHandler:
    keys = ['fsm']
    """
    EventHandler responsible for managing the state machine of
    the framework (TFW FSM).

    tfw.networking.TFWServer instances automatically send 'trigger'
    commands to the event handler listening on the 'fsm' key,
    which should be an instance of this event handler.

    This event handler accepts messages that have a
    data['command'] key specifying a command to be executed.

    An 'fsm_update' message is broadcasted after every successful
    command.
    """
    def __init__(self, *, fsm_type, require_signature=False):
        self.fsm = fsm_type()
        self._fsm_updater = FSMUpdater(self.fsm)
        self.auth_key = KeyManager().auth_key
        self._require_signature = require_signature

        self.command_handlers = {
            'trigger': self.handle_trigger,
            'update':  self.handle_update
        }

    def handle_event(self, message, server_connector):
        try:
            message = self.command_handlers[message['data']['command']](message)
            if message:
                fsm_update_message = self._fsm_updater.fsm_update
                sign_message(self.auth_key, message)
                sign_message(self.auth_key, fsm_update_message)
                server_connector.send_message(fsm_update_message, Scope.BROADCAST)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def handle_trigger(self, message):
        """
        Attempts to step the FSM with the supplied trigger.

        :param message: TFW message with a data field containing
                        the action to try triggering in data['value']
        """
        trigger = message['data']['value']
        if self._require_signature:
            if not verify_message(self.auth_key, message):
                LOG.error('Ignoring unsigned trigger command: %s', message)
                return None
        if self.fsm.step(trigger):
            return message
        return None

    def handle_update(self, message):
        """
        Does nothing, but triggers an 'fsm_update' message.
        """
        # pylint: disable=no-self-use
        return message

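The message shapes FSMHandler consumes follow directly from the handlers above; the trigger value below is illustrative.

# A 'trigger' command tries to step the machine with data['value']:
trigger_message = {
    'key': 'fsm',
    'data': {'command': 'trigger', 'value': 'step_1'}
}

# An 'update' command changes nothing but forces an 'fsm_update' broadcast:
update_message = {
    'key': 'fsm',
    'data': {'command': 'update'}
}
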
							
								
								
									
tfw/components/fsm/fsm_updater.py (new file, 25 lines)
@@ -0,0 +1,25 @@
class FSMUpdater:
    def __init__(self, fsm):
        self.fsm = fsm

    @property
    def fsm_update(self):
        return {
            'key': 'fsm_update',
            **self.fsm_update_data
        }

    @property
    def fsm_update_data(self):
        valid_transitions = [
            {'trigger': trigger}
            for trigger in self.fsm.get_triggers(self.fsm.state)
        ]
        last_fsm_event = self.fsm.event_log[-1]
        last_fsm_event['timestamp'] = last_fsm_event['timestamp'].isoformat()
        return {
            'current_state': self.fsm.state,
            'valid_transitions': valid_transitions,
            'in_accepted_state': self.fsm.in_accepted_state,
            'last_event': last_fsm_event
        }

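The broadcast built by FSMUpdater.fsm_update therefore has this shape; the state names and trigger are illustrative, and the contents of last_event beyond the timestamp depend on what the FSM's event log records.

fsm_update_example = {
    'key': 'fsm_update',
    'current_state': 'idle',
    'valid_transitions': [{'trigger': 'step_1'}],
    'in_accepted_state': False,
    'last_event': {'timestamp': '2019-01-01T00:00:00'}  # plus whatever the FSM logs per event
}
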
							
								
								
									
tfw/components/ide/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .ide_handler import IdeHandler

tfw/components/ide/file_manager/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .file_manager import FileManager

tfw/components/ide/file_manager/file_manager.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from typing import Iterable
from glob import glob
from fnmatch import fnmatchcase
from os.path import basename, isfile, join, relpath, exists, isdir, realpath


class FileManager:  # pylint: disable=too-many-instance-attributes
    def __init__(self, working_directory, allowed_directories, selected_file=None, exclude=None):
        self._exclude, self.exclude = [], exclude
        self._allowed_directories, self.allowed_directories = None, allowed_directories
        self._workdir, self.workdir = None, working_directory
        self._filename, self.filename = None, selected_file or self.files[0]

    @property
    def exclude(self):
        return self._exclude

    @exclude.setter
    def exclude(self, exclude):
        if exclude is None:
            return
        if not isinstance(exclude, Iterable):
            raise TypeError('Exclude must be Iterable!')
        self._exclude = exclude

    @property
    def workdir(self):
        return self._workdir

    @workdir.setter
    def workdir(self, directory):
        if not exists(directory) or not isdir(directory):
            raise EnvironmentError(f'"{directory}" is not a directory!')
        if not self._is_in_allowed_dir(directory):
            raise EnvironmentError(f'Directory "{directory}" is not allowed!')
        self._workdir = directory

    @property
    def allowed_directories(self):
        return self._allowed_directories

    @allowed_directories.setter
    def allowed_directories(self, directories):
        self._allowed_directories = [realpath(directory) for directory in directories]

    @property
    def filename(self):
        return self._filename

    @filename.setter
    def filename(self, filename):
        if filename not in self.files:
            raise EnvironmentError('No such file in workdir!')
        self._filename = filename

    @property
    def files(self):
        return [
            self._relpath(file)
            for file in glob(join(self._workdir, '**/*'), recursive=True)
            if isfile(file)
            and self._is_in_allowed_dir(file)
            and not self._is_blacklisted(file)
        ]

    @property
    def file_contents(self):
        with open(self._filepath(self.filename), 'rb', buffering=0) as ifile:
            return ifile.read().decode(errors='surrogateescape')

    @file_contents.setter
    def file_contents(self, value):
        with open(self._filepath(self.filename), 'wb', buffering=0) as ofile:
            ofile.write(value.encode())

    def _is_in_allowed_dir(self, path):
        return any(
            realpath(path).startswith(allowed_dir)
            for allowed_dir in self.allowed_directories
        )

    def _is_blacklisted(self, file):
        return any(
            fnmatchcase(file, blacklisted) or
            fnmatchcase(basename(file), blacklisted)
            for blacklisted in self.exclude
        )

    def _filepath(self, filename):
        return join(self._workdir, filename)

    def _relpath(self, filename):
        return relpath(self._filepath(filename), start=self._workdir)

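A minimal usage sketch of FileManager against a throwaway directory; the file name and contents are illustrative. Because no selected_file is passed, the first discovered file is selected automatically.

from tempfile import TemporaryDirectory
from pathlib import Path

from tfw.components.ide.file_manager import FileManager

with TemporaryDirectory() as workdir:
    Path(workdir, 'hello.txt').write_text('hi')
    manager = FileManager(workdir, [workdir], exclude=['*.pyc'])
    print(manager.files)            # ['hello.txt']
    manager.file_contents = 'hello world'   # writes the selected file (hello.txt)
    print(manager.file_contents)    # 'hello world'
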
							
								
								
									
tfw/components/ide/file_manager/test_file_manager.py (new file, 124 lines)
@@ -0,0 +1,124 @@
# pylint: disable=redefined-outer-name

from dataclasses import dataclass
from secrets import token_urlsafe
from os.path import join
from os import chdir, mkdir, symlink
from pathlib import Path
from tempfile import TemporaryDirectory

import pytest

from .file_manager import FileManager

@dataclass
class ManagerContext:
    folder: str
    manager: FileManager

    def join(self, path):
        return join(self.folder, path)


@pytest.fixture()
def context():
    dirs = {}

    with TemporaryDirectory() as workdir:
        chdir(workdir)
        for name in ['allowed', 'excluded', 'invis']:
            node = join(workdir, name)
            mkdir(node)
            Path(join(node, 'empty.txt')).touch()
            Path(join(node, 'empty.bin')).touch()
            dirs[name] = node

        yield ManagerContext(
            workdir,
            FileManager(
                dirs['allowed'],
                [dirs['allowed'], dirs['excluded']],
                exclude=['*/excluded/*']
            )
        )

@pytest.mark.parametrize('subdir', ['allowed/', 'excluded/'])
def test_select_allowed_dirs(context, subdir):
    context.manager.workdir = context.join(subdir)
    assert context.manager.workdir == context.join(subdir)
    newdir = context.join(subdir+'deep')
    mkdir(newdir)
    context.manager.workdir = newdir
    assert context.manager.workdir == newdir

@pytest.mark.parametrize('invdir', ['', 'invis'])
def test_select_forbidden_dirs(context, invdir):
    fullpath = context.join(invdir)
    with pytest.raises(OSError):
        context.manager.workdir = fullpath
    assert context.manager.workdir != fullpath
    context.manager.allowed_directories += [fullpath]
    context.manager.workdir = fullpath
    assert context.manager.workdir == fullpath


@pytest.mark.parametrize('filename', ['another.txt', '*.txt'])
def test_select_allowed_files(context, filename):
    Path(context.join('allowed/'+filename)).touch()
    assert filename in context.manager.files
    context.manager.filename = filename
    assert context.manager.filename == filename

@pytest.mark.parametrize('path', [
    {'dir': 'allowed/', 'file': 'illegal.bin'},
    {'dir': 'excluded/', 'file': 'legal.txt'},
    {'dir': 'allowed/', 'file': token_urlsafe(16)+'.bin'},
    {'dir': 'excluded/', 'file': token_urlsafe(16)+'.txt'},
    {'dir': 'allowed/', 'file': token_urlsafe(32)+'.bin'},
    {'dir': 'excluded/', 'file': token_urlsafe(32)+'.txt'}
])
def test_select_excluded_files(context, path):
    context.manager.workdir = context.join(path['dir'])
    context.manager.exclude = ['*/excluded/*', '*.bin']
    Path(context.join(path['dir']+path['file'])).touch()
    assert path['file'] not in context.manager.files
    with pytest.raises(OSError):
        context.manager.filename = path['file']

@pytest.mark.parametrize('path', [
    {'src': 'excluded/empty.txt', 'dst': 'allowed/link.txt'},
    {'src': 'invis/empty.txt', 'dst': 'allowed/link.txt'},
    {'src': 'excluded/empty.txt', 'dst': 'allowed/'+token_urlsafe(16)+'.txt'},
    {'src': 'invis/empty.txt', 'dst': 'allowed/'+token_urlsafe(16)+'.txt'},
    {'src': 'excluded/empty.txt', 'dst': 'allowed/'+token_urlsafe(32)+'.txt'},
    {'src': 'invis/empty.txt', 'dst': 'allowed/'+token_urlsafe(32)+'.txt'}
])
def test_select_excluded_symlinks(context, path):
    symlink(context.join(path['src']), context.join(path['dst']))
    assert path['dst'] not in context.manager.files

def test_read_write_file(context):
    for _ in range(128):
        context.manager.filename = 'empty.txt'
        content = token_urlsafe(32)
        context.manager.file_contents = content
        assert context.manager.file_contents == content
        with open(context.join('allowed/empty.txt'), 'r') as ifile:
            assert ifile.read() == content

def test_regular_ide_actions(context):
    context.manager.workdir = context.join('allowed')
    newfile1, newfile2 = token_urlsafe(16), token_urlsafe(16)
    Path(context.join(f'allowed/{newfile1}')).touch()
    Path(context.join(f'allowed/{newfile2}')).touch()
    for _ in range(8):
        context.manager.filename = newfile1
        content1 = token_urlsafe(32)
        context.manager.file_contents = content1
        context.manager.filename = newfile2
        content2 = token_urlsafe(32)
        context.manager.file_contents = content2
        context.manager.filename = newfile1
        assert context.manager.file_contents == content1
        context.manager.filename = newfile2
        assert context.manager.file_contents == content2

tfw/components/ide/ide_handler.py (new file, 196 lines)
@@ -0,0 +1,196 @@
import logging

from tfw.internals.networking import Scope
from tfw.internals.inotify import InotifyObserver

from .file_manager import FileManager


LOG = logging.getLogger(__name__)

BUILD_ARTIFACTS = (
    "*.a",
    "*.class",
    "*.dll",
    "*.dylib",
    "*.elf",
    "*.exe",
    "*.jar",
    "*.ko",
    "*.la",
    "*.lib",
    "*.lo",
    "*.o",
    "*.obj",
    "*.out",
    "*.py[cod]",
    "*.so",
    "*.so.*",
    "*.tar.gz",
    "*.zip",
    "*__pycache__*"
)


class IdeHandler:
    keys = ['ide']
    # pylint: disable=too-many-arguments,anomalous-backslash-in-string
    """
    Event handler implementing the backend of our browser based IDE.
    By default all files in the directory specified in __init__ are displayed
    on the frontend. Note that this is a stateful component.

    When any file in the selected directory changes they are automatically refreshed
    on the frontend (this is done by listening to inotify events).

    This EventHandler accepts messages that have a data['command'] key specifying
    a command to be executed.

    The API of each command is documented in their respective handler.
    """
    def __init__(self, *, directory, allowed_directories, selected_file=None, exclude=None):
        """
        :param key: the key this instance should listen to
        :param directory: working directory which the EventHandler should serve files from
        :param allowed_directories: list of directories that can be switched to using selectdir
        :param selected_file: file that is selected by default
        :param exclude: list of filenames that should not appear between files (for .o, .pyc, etc.)
        """
        self.server_connector = None
        try:
            self.filemanager = FileManager(
                allowed_directories=allowed_directories,
                working_directory=directory,
                selected_file=selected_file,
                exclude=exclude
            )
        except IndexError:
            raise EnvironmentError(
                f'No file(s) in IdeEventHandler working_directory "{directory}"!'
            )

        self.monitor = InotifyObserver(
            self.filemanager.allowed_directories,
            exclude=BUILD_ARTIFACTS
        )
        self.monitor.on_modified = self._reload_frontend
        self.monitor.start()

        self.commands = {
            'read':      self.read,
            'write':     self.write,
            'select':    self.select,
            'selectdir': self.select_dir,
            'exclude':   self.exclude
        }

    def _reload_frontend(self, event):  # pylint: disable=unused-argument
        self.send_message({
            'key': 'ide',
            'data': {'command': 'reload'}
        })

    def send_message(self, message):
        self.server_connector.send_message(message, scope=Scope.WEBSOCKET)

    def read(self, data):
        """
        Read the currently selected file.

        :return dict: TFW message data containing key 'content'
                      (contents of the selected file)
        """
        try:
            data['content'] = self.filemanager.file_contents
        except PermissionError:
            data['content'] = 'You have no permission to open that file :('
        except FileNotFoundError:
            data['content'] = 'This file was removed :('
        except Exception: # pylint: disable=broad-except
            data['content'] = 'Failed to read file :('
        return data

    def write(self, data):
        """
        Overwrites a file with the desired string.

        :param data: TFW message data containing key 'content'
                     (new file content)

        """
        try:
            self.filemanager.file_contents = data['content']
        except Exception: # pylint: disable=broad-except
            LOG.exception('Error writing file!')
        del data['content']
        return data

    def select(self, data):
        """
        Selects a file from the current directory.

        :param data: TFW message data containing 'filename'
                     (name of file to select relative to the current directory)
        """
        try:
            self.filemanager.filename = data['filename']
        except EnvironmentError:
            LOG.exception('Failed to select file "%s"', data['filename'])
        return data

    def select_dir(self, data):
        """
        Select a new working directory to display files from.

        :param data: TFW message data containing 'directory'
                     (absolute path of directory to select.
                     must be a path whitelisted in
                     self.allowed_directories)
        """
        try:
            self.filemanager.workdir = data['directory']
            try:
                self.filemanager.filename = self.filemanager.files[0]
                self.read(data)
            except IndexError:
                data['content'] = 'No files in this directory :('
        except EnvironmentError as err:
            LOG.error(
                'Failed to select directory "%s". Reason: %s',
                data['directory'], str(err)
            )
        return data

    def exclude(self, data):
        """
        Overwrite list of excluded files

        :param data: TFW message data containing 'exclude'
                     (list of unix-style filename patterns to be excluded,
                     e.g.: ["\*.pyc", "\*.o"])
        """
        try:
            self.filemanager.exclude = list(data['exclude'])
        except TypeError:
            LOG.error('Exclude must be Iterable!')
        return data

    def attach_fileinfo(self, data):
        """
        Basic information included in every response to the frontend.
        """
        data['filename'] = self.filemanager.filename
        data['files'] = self.filemanager.files
        data['directory'] = self.filemanager.workdir

    def handle_event(self, message, _):
        try:
            data = message['data']
            message['data'] = self.commands[data['command']](data)
            self.attach_fileinfo(data)
            self.send_message(message)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def cleanup(self):
        self.monitor.stop()

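The frontend drives IdeHandler with messages shaped like the following; the shapes come from the command handlers above, while the concrete file name, path, and content are illustrative.

read_message = {'key': 'ide', 'data': {'command': 'read'}}
write_message = {'key': 'ide', 'data': {'command': 'write', 'content': 'print("hi")\n'}}
select_message = {'key': 'ide', 'data': {'command': 'select', 'filename': 'app.py'}}
selectdir_message = {'key': 'ide', 'data': {'command': 'selectdir', 'directory': '/home/user/workdir'}}
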
							
								
								
									
tfw/components/pipe_io/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .pipe_io_handler import PipeIOHandler, PipeIOHandlerBase, TransformerPipeIOHandler, CommandHandler

tfw/components/pipe_io/pipe_io_handler.py (new file, 143 lines)
@@ -0,0 +1,143 @@
import logging
from abc import abstractmethod
from json import loads, dumps
from subprocess import run, PIPE, Popen
from functools import partial
from os import getpgid, killpg
from os.path import join
from signal import SIGTERM
from secrets import token_urlsafe
from threading import Thread
from contextlib import suppress

from .pipe_io_server import PipeIOServer, terminate_process_on_failure


LOG = logging.getLogger(__name__)
DEFAULT_PERMISSIONS = 0o600


class PipeIOHandlerBase:
    keys = ['']

    def __init__(self, in_pipe_path, out_pipe_path, permissions=DEFAULT_PERMISSIONS):
        self.server_connector = None
        self.pipe_io = CallbackPipeIOServer(
            in_pipe_path,
            out_pipe_path,
            self.handle_pipe_event,
            permissions
        )
        self.pipe_io.start()

    @abstractmethod
    def handle_pipe_event(self, message_bytes):
        raise NotImplementedError()

    def cleanup(self):
        self.pipe_io.stop()


class CallbackPipeIOServer(PipeIOServer):
    def __init__(self, in_pipe_path, out_pipe_path, callback, permissions):
        super().__init__(in_pipe_path, out_pipe_path, permissions)
        self.callback = callback

    def handle_message(self, message):
        try:
            self.callback(message)
        except:  # pylint: disable=bare-except
            LOG.exception('Failed to handle message %s from pipe %s!', message, self.in_pipe)


class PipeIOHandler(PipeIOHandlerBase):
    def handle_event(self, message, _):
        json_bytes = dumps(message).encode()
        self.pipe_io.send_message(json_bytes)

    def handle_pipe_event(self, message_bytes):
        json = loads(message_bytes)
        self.server_connector.send_message(json)


class TransformerPipeIOHandler(PipeIOHandlerBase):
    # pylint: disable=too-many-arguments
    def __init__(
            self, in_pipe_path, out_pipe_path,
            transform_in_cmd, transform_out_cmd,
            permissions=DEFAULT_PERMISSIONS
    ):
        self._transform_in = partial(self._transform_message, transform_in_cmd)
        self._transform_out = partial(self._transform_message, transform_out_cmd)
        super().__init__(in_pipe_path, out_pipe_path, permissions)

    @staticmethod
    def _transform_message(transform_cmd, message):
        proc = run(
            transform_cmd,
            input=message,
            stdout=PIPE,
            stderr=PIPE,
            shell=True
        )
        if proc.returncode == 0:
            return proc.stdout
        raise ValueError(f'Transforming message {message} failed!')

    def handle_event(self, message, _):
        json_bytes = dumps(message).encode()
        transformed_bytes = self._transform_out(json_bytes)
        if transformed_bytes:
            self.pipe_io.send_message(transformed_bytes)

    def handle_pipe_event(self, message_bytes):
        transformed_bytes = self._transform_in(message_bytes)
        if transformed_bytes:
            json_message = loads(transformed_bytes)
            self.server_connector.send_message(json_message)


class CommandHandler(PipeIOHandler):
    def __init__(self, command, permissions=DEFAULT_PERMISSIONS):
        super().__init__(
            self._generate_tempfilename(),
            self._generate_tempfilename(),
            permissions
        )

        self._proc_stdin = open(self.pipe_io.out_pipe, 'rb')
        self._proc_stdout = open(self.pipe_io.in_pipe, 'wb')
        self._proc = Popen(
            command, shell=True, executable='/bin/bash',
            stdin=self._proc_stdin, stdout=self._proc_stdout, stderr=PIPE,
            start_new_session=True
        )
        self._monitor_proc_thread = self._start_monitor_proc()

    def _generate_tempfilename(self):
        # pylint: disable=no-self-use
        random_filename = partial(token_urlsafe, 10)
        return join('/tmp', f'{type(self).__name__}.{random_filename()}')

    def _start_monitor_proc(self):
        thread = Thread(target=self._monitor_proc, daemon=True)
        thread.start()
        return thread

    @terminate_process_on_failure
    def _monitor_proc(self):
        return_code = self._proc.wait()
        if return_code == -int(SIGTERM):
            # supervisord asked the program to terminate, this is fine
            return
        if return_code != 0:
            _, stderr = self._proc.communicate()
            raise RuntimeError(f'Subprocess failed ({return_code})! Stderr:\n{stderr.decode()}')

    def cleanup(self):
        with suppress(ProcessLookupError):
            process_group_id = getpgid(self._proc.pid)
            killpg(process_group_id, SIGTERM)
            self._proc_stdin.close()
            self._proc_stdout.close()
        super().cleanup()

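An illustrative sketch of CommandHandler bridging TFW messages and a shell process over the generated FIFO pair. The PrintingConnector is a hypothetical stand-in for the real server connector, and 'cat' is used only because it echoes every JSON line straight back.

from time import sleep

from tfw.components.pipe_io import CommandHandler

class PrintingConnector:  # hypothetical stand-in for the real server connector
    def send_message(self, message):
        print('from process:', message)

handler = CommandHandler('cat')                 # 'cat' echoes every JSON line back
handler.server_connector = PrintingConnector()
handler.handle_event({'key': 'demo', 'data': {'value': 1}}, None)
sleep(0.5)                                      # give the round trip through the FIFOs a moment
handler.cleanup()
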
							
								
								
									
tfw/components/pipe_io/pipe_io_server/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .pipe_io_server import PipeIOServer
from .terminate_process_on_failure import terminate_process_on_failure

tfw/components/pipe_io/pipe_io_server/deque.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from collections import deque
from threading import Lock, Condition


class Deque:
    def __init__(self):
        self._queue = deque()

        self._mutex = Lock()
        self._not_empty = Condition(self._mutex)

    def pop(self):
        with self._mutex:
            while not self._queue:
                self._not_empty.wait()
            return self._queue.pop()

    def push(self, item):
        self._push(item, self._queue.appendleft)

    def push_front(self, item):
        self._push(item, self._queue.append)

    def _push(self, item, put_method):
        with self._mutex:
            put_method(item)
            self._not_empty.notify()

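A small sketch of the blocking behaviour, assuming a direct module import: pop() waits on the condition variable until some other thread pushes an item.

from threading import Thread

from tfw.components.pipe_io.pipe_io_server.deque import Deque

queue = Deque()
Thread(target=lambda: queue.push(b'hello'), daemon=True).start()
print(queue.pop())   # blocks until the producer thread pushes, then prints b'hello'
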
							
								
								
									
tfw/components/pipe_io/pipe_io_server/pipe.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from os import mkfifo, remove, chmod
from os.path import exists


class Pipe:
    def __init__(self, path):
        self.path = path

    def recreate(self, permissions):
        self.remove()
        mkfifo(self.path)
        chmod(self.path, permissions)  # use chmod to ignore umask

    def remove(self):
        if exists(self.path):
            remove(self.path)

tfw/components/pipe_io/pipe_io_server/pipe_io_server.py (new file, 73 lines)
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from threading import Thread, Event
from typing import Callable

from .pipe_reader_thread import PipeReaderThread
from .pipe_writer_thread import PipeWriterThread
from .pipe import Pipe
from .terminate_process_on_failure import terminate_process_on_failure


class PipeIOServer(ABC, Thread):
    def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600):
        super().__init__(daemon=True)
        self._in_pipe, self._out_pipe = in_pipe, out_pipe
        self._create_pipes(permissions)
        self._stop_event = Event()
        self._reader_thread, self._writer_thread = self._create_io_threads()
        self._io_threads = (self._reader_thread, self._writer_thread)
        self._on_stop = lambda: None

    def _create_pipes(self, permissions):
        Pipe(self.in_pipe).recreate(permissions)
        Pipe(self.out_pipe).recreate(permissions)

    @property
    def in_pipe(self):
        return self._in_pipe

    @property
    def out_pipe(self):
        return self._out_pipe

    def _create_io_threads(self):
        reader_thread = PipeReaderThread(self.in_pipe, self._stop_event, self.handle_message)
        writer_thread = PipeWriterThread(self.out_pipe, self._stop_event)
        return reader_thread, writer_thread

    @abstractmethod
    def handle_message(self, message):
        raise NotImplementedError()

    def send_message(self, message):
        self._writer_thread.write(message)

    @terminate_process_on_failure
    def run(self):
        for thread in self._io_threads:
            thread.start()
        self._stop_event.wait()
        self._stop_threads()

    def stop(self):
        self._stop_event.set()
        if self.is_alive():
            self.join()

    def _stop_threads(self):
        for thread in self._io_threads:
            if thread.is_alive():
                thread.stop()
        Pipe(self.in_pipe).remove()
        Pipe(self.out_pipe).remove()
        self._on_stop()

    def _set_on_stop(self, value):
        if not isinstance(value, Callable):
            raise ValueError("Supplied object is not callable!")
        self._on_stop = value

    on_stop = property(fset=_set_on_stop)

    def wait(self):
        self._stop_event.wait()

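A minimal subclass sketch: handle_message is the only abstract method, and send_message pushes a reply to the out pipe. The FIFO paths below are illustrative; the server (re)creates them itself.

from tfw.components.pipe_io.pipe_io_server import PipeIOServer

class EchoPipeIOServer(PipeIOServer):
    def handle_message(self, message):
        # each line read from in_pipe arrives here without its trailing newline
        self.send_message(b'echo: ' + message)

server = EchoPipeIOServer('/tmp/echo_in', '/tmp/echo_out')  # FIFO paths are illustrative
server.start()   # reader/writer threads run until stop() is called
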
							
								
								
									
tfw/components/pipe_io/pipe_io_server/pipe_reader_thread.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from contextlib import suppress
from os import open as osopen
from os import write, close, O_WRONLY, O_NONBLOCK
from threading import Thread

from .terminate_process_on_failure import terminate_process_on_failure


class PipeReaderThread(Thread):
    eof = b''
    stop_sequence = b'stop_reading\n'

    def __init__(self, pipe_path, stop_event, message_handler):
        super().__init__(daemon=True)
        self._message_handler = message_handler
        self._pipe_path = pipe_path
        self._stop_event = stop_event

    @terminate_process_on_failure
    def run(self):
        with self._open() as pipe:
            while True:
                message = pipe.readline()
                if message == self.stop_sequence:
                    self._stop_event.set()
                    break
                if message == self.eof:
                    self._open().close()
                    continue
                self._message_handler(message[:-1])

    def _open(self):
        return open(self._pipe_path, 'rb')

    def stop(self):
        while self.is_alive():
            self._unblock()
        self.join()

    def _unblock(self):
        with suppress(OSError):
            fd = osopen(self._pipe_path, O_WRONLY | O_NONBLOCK)
            write(fd, self.stop_sequence)
            close(fd)

tfw/components/pipe_io/pipe_io_server/pipe_writer_thread.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from contextlib import suppress
from os import O_NONBLOCK, O_RDONLY, close
from os import open as osopen
from threading import Thread

from .terminate_process_on_failure import terminate_process_on_failure
from .deque import Deque


class PipeWriterThread(Thread):
    def __init__(self, pipe_path, stop_event):
        super().__init__(daemon=True)
        self._pipe_path = pipe_path
        self._stop_event = stop_event
        self._write_queue = Deque()

    def write(self, message):
        self._write_queue.push(message)

    @terminate_process_on_failure
    def run(self):
        with self._open() as pipe:
            while True:
                message = self._write_queue.pop()
                if message is None:
                    self._stop_event.set()
                    break
                try:
                    pipe.write(message + b'\n')
                    pipe.flush()
                except BrokenPipeError:
                    try:  # pipe was reopened, close() flushed the message
                        pipe.close()
                    except BrokenPipeError:  # close() discarded the message
                        self._write_queue.push_front(message)
                    pipe = self._open()

    def _open(self):
        return open(self._pipe_path, 'wb')

    def stop(self):
        while self.is_alive():
            self._unblock()
        self.join()

    def _unblock(self):
        with suppress(OSError):
            fd = osopen(self._pipe_path, O_RDONLY | O_NONBLOCK)
            self._write_queue.push_front(None)
            close(fd)

tfw/components/pipe_io/pipe_io_server/terminate_process_on_failure.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from functools import wraps
from os import kill, getpid
from signal import SIGTERM
from traceback import print_exc


def terminate_process_on_failure(fun):
    @wraps(fun)
    def wrapper(*args, **kwargs):
        try:
            return fun(*args, **kwargs)
        except:  # pylint: disable=bare-except
            print_exc()
            kill(getpid(), SIGTERM)
    return wrapper

							
								
								
									
tfw/components/process_management/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .process_handler import ProcessHandler
from .process_log_handler import ProcessLogHandler
							
								
								
									
tfw/components/process_management/log_inotify_observer.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import logging

from tfw.internals.networking import Scope
from tfw.internals.inotify import InotifyObserver

from .supervisor import ProcessLogManager


class LogInotifyObserver(InotifyObserver, ProcessLogManager):
    def __init__(self, server_connector, supervisor_uri, process_name, log_tail=0):
        self._prevent_log_recursion()
        self._server_connector = server_connector
        self._process_name = process_name
        self.log_tail = log_tail
        self._procinfo = None
        ProcessLogManager.__init__(self, supervisor_uri)
        InotifyObserver.__init__(self, self._get_logfiles())

    @staticmethod
    def _prevent_log_recursion():
        # This is done to prevent inotify event logs triggering themselves (infinite log recursion)
        logging.getLogger('watchdog.observers.inotify_buffer').propagate = False

    def _get_logfiles(self):
        self._procinfo = self.supervisor.getProcessInfo(self._process_name)
        return self._procinfo['stdout_logfile'], self._procinfo['stderr_logfile']

    @property
    def process_name(self):
        return self._process_name

    @process_name.setter
    def process_name(self, process_name):
        self._process_name = process_name
        self.paths = self._get_logfiles()

    def on_modified(self, event):
        self._server_connector.send_message({
            'key': 'processlog',
            'data': {
                'command': 'new_log',
                'stdout': self.read_stdout(self.process_name, tail=self.log_tail),
                'stderr': self.read_stderr(self.process_name, tail=self.log_tail)
            }
        }, Scope.BROADCAST)
							
								
								
									
tfw/components/process_management/process_handler.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import logging
from xmlrpc.client import Fault as SupervisorFault

from tfw.internals.networking import Scope

from .supervisor import ProcessManager, ProcessLogManager


LOG = logging.getLogger(__name__)


class ProcessHandler(ProcessManager, ProcessLogManager):
    keys = ['processmanager']
    """
    Event handler that can manage processes managed by supervisor.

    This EventHandler accepts messages that have a data['command'] key specifying
    a command to be executed.
    Every message must contain a data['process_name'] field with the name of the
    process to manage. This is the name specified in supervisor config files like so:
    [program:someprogram]

    Commands available: start, stop, restart, readlog
    (the names are as self-documenting as it gets)
    """
    def __init__(self, *, supervisor_uri, log_tail=0):
        ProcessManager.__init__(self, supervisor_uri)
        ProcessLogManager.__init__(self, supervisor_uri)
        self.log_tail = log_tail
        self.commands = {
            'start': self.start_process,
            'stop': self.stop_process,
            'restart': self.restart_process
        }

    def handle_event(self, message, server_connector):
        try:
            data = message['data']
            try:
                self.commands[data['command']](data['process_name'])
            except SupervisorFault as fault:
                message['data']['error'] = fault.faultString
            finally:
                message['data']['stdout'] = self.read_stdout(
                    data['process_name'],
                    self.log_tail
                )
                message['data']['stderr'] = self.read_stderr(
                    data['process_name'],
                    self.log_tail
                )
            server_connector.send_message(message, scope=Scope.WEBSOCKET)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
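The docstring above describes the message contract only in prose, so here is a minimal, hypothetical sketch of driving ProcessHandler directly. The supervisord URI, the program name and the stub connector are assumptions for illustration and are not part of this commit, which normally injects the connector through the event handling machinery.

from tfw.components.process_management import ProcessHandler

class StubConnector:  # stand-in for the real server connector (assumption)
    def send_message(self, message, scope=None):
        print(message, scope)

handler = ProcessHandler(supervisor_uri='http://localhost:9001/RPC2', log_tail=100)
handler.handle_event({
    'key': 'processmanager',
    'data': {
        'command': 'restart',          # one of: start, stop, restart
        'process_name': 'someprogram'  # [program:someprogram] in the supervisor config
    }
}, StubConnector())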
							
								
								
									
tfw/components/process_management/process_log_handler.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import logging

from .log_inotify_observer import LogInotifyObserver


LOG = logging.getLogger(__name__)


class ProcessLogHandler:
    keys = ['logmonitor']
    """
    Monitors the output of a supervisor process (stdout, stderr) and
    sends the results to the frontend.

    Accepts messages that have a data['command'] key specifying
    a command to be executed.

    The API of each command is documented in their respective handler.
    """
    def __init__(self, *, process_name, supervisor_uri, log_tail=0):
        self.server_connector = None
        self.process_name = process_name
        self._supervisor_uri = supervisor_uri
        self._initial_log_tail = log_tail
        self._monitor = None

        self.command_handlers = {
            'process_name': self.handle_process_name,
            'log_tail':     self.handle_log_tail
        }

    def start(self):
        self._monitor = LogInotifyObserver(
            server_connector=self.server_connector,
            supervisor_uri=self._supervisor_uri,
            process_name=self.process_name,
            log_tail=self._initial_log_tail
        )
        self._monitor.start()

    def handle_event(self, message, _):
        try:
            data = message['data']
            self.command_handlers[data['command']](data)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def handle_process_name(self, data):
        """
        Changes the monitored process.

        :param data: TFW message data containing 'value'
                     (name of the process to monitor)
        """
        self._monitor.process_name = data['value']

    def handle_log_tail(self, data):
        """
        Sets tail length of the log the monitor will send
        to the frontend (the monitor will send back the last
        'value' characters of the log).

        :param data: TFW message data containing 'value'
                     (new tail length)
        """
        self._monitor.log_tail = data['value']

    def cleanup(self):
        self._monitor.stop()
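As with ProcessHandler, a short hypothetical sketch may help; the program name, the supervisord URI and the connector stub below are placeholders, and in practice the framework injects server_connector before start() is called.

from tfw.components.process_management import ProcessLogHandler

class StubConnector:  # placeholder for the framework-provided connector
    def send_message(self, message, scope=None):
        print(message, scope)

log_handler = ProcessLogHandler(
    process_name='someprogram',                   # assumed supervisor program name
    supervisor_uri='http://localhost:9001/RPC2',  # assumed supervisord XML-RPC endpoint
    log_tail=2000
)
log_handler.server_connector = StubConnector()
log_handler.start()

# Switch which process is monitored at runtime:
log_handler.handle_event({
    'key': 'logmonitor',
    'data': {'command': 'process_name', 'value': 'otherprogram'}
}, None)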
							
								
								
									
tfw/components/process_management/supervisor.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from os import remove
from contextlib import suppress
import xmlrpc.client
from xmlrpc.client import Fault as SupervisorFault


class SupervisorBase:
    def __init__(self, supervisor_uri):
        self.supervisor = xmlrpc.client.ServerProxy(supervisor_uri).supervisor


class ProcessManager(SupervisorBase):
    def stop_process(self, process_name):
        with suppress(SupervisorFault):
            self.supervisor.stopProcess(process_name)

    def start_process(self, process_name):
        self.supervisor.startProcess(process_name)

    def restart_process(self, process_name):
        self.stop_process(process_name)
        self.start_process(process_name)


class ProcessLogManager(SupervisorBase):
    def read_stdout(self, process_name, tail=0):
        return self.supervisor.readProcessStdoutLog(process_name, -tail, 0)

    def read_stderr(self, process_name, tail=0):
        return self.supervisor.readProcessStderrLog(process_name, -tail, 0)

    def clear_logs(self, process_name):
        for logfile in ('stdout_logfile', 'stderr_logfile'):
            with suppress(FileNotFoundError):
                remove(self.supervisor.getProcessInfo(process_name)[logfile])
        self.supervisor.clearProcessLogs(process_name)
							
								
								
									
tfw/components/snapshots/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .snapshot_handler import SnapshotHandler
							
								
								
									
tfw/components/snapshots/snapshot_handler.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import logging
from os.path import join as joinpath
from os.path import basename
from os import makedirs
from datetime import datetime

from dateutil import parser as dateparser

from tfw.internals.networking import Scope

from .snapshot_provider import SnapshotProvider


LOG = logging.getLogger(__name__)


class SnapshotHandler:
    keys = ['snapshot']

    def __init__(self, *, directories, snapshots_dir, exclude_unix_patterns=None):
        self._snapshots_dir = snapshots_dir
        self.snapshot_providers = {}
        self._exclude_unix_patterns = exclude_unix_patterns
        self.init_snapshot_providers(directories)

        self.command_handlers = {
            'take_snapshot': self.handle_take_snapshot,
            'restore_snapshot': self.handle_restore_snapshot,
            'exclude': self.handle_exclude
        }

    def init_snapshot_providers(self, directories):
        for index, directory in enumerate(directories):
            git_dir = self.init_git_dir(index, directory)
            self.snapshot_providers[directory] = SnapshotProvider(
                directory,
                git_dir,
                self._exclude_unix_patterns
            )

    def init_git_dir(self, index, directory):
        git_dir = joinpath(
            self._snapshots_dir,
            f'{basename(directory)}-{index}'
        )
        makedirs(git_dir, exist_ok=True)
        return git_dir

    def handle_event(self, message, server_connector):
        try:
            data = message['data']
            message['data'] = self.command_handlers[data['command']](data)
            server_connector.send_message(message, scope=Scope.WEBSOCKET)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def handle_take_snapshot(self, data):
        LOG.debug('Taking snapshots of directories %s', self.snapshot_providers.keys())
        for provider in self.snapshot_providers.values():
            provider.take_snapshot()
        return data

    def handle_restore_snapshot(self, data):
        date = dateparser.parse(
            data.get(
                'value',
                datetime.now().isoformat()
            )
        )
        LOG.debug(
            'Restoring snapshots (@ %s) of directories %s',
            date,
            self.snapshot_providers.keys()
        )
        for provider in self.snapshot_providers.values():
            provider.restore_snapshot(date)
        return data

    def handle_exclude(self, data):
        exclude_unix_patterns = data['value']
        if not isinstance(exclude_unix_patterns, list):
            raise KeyError

        for provider in self.snapshot_providers.values():
            provider.exclude = exclude_unix_patterns
        return data
							
								
								
									
tfw/components/snapshots/snapshot_provider.py (new file, 221 lines)
@@ -0,0 +1,221 @@
import re
from subprocess import run, CalledProcessError, PIPE
from getpass import getuser
from os.path import isdir
from os.path import join as joinpath
from uuid import uuid4

from dateutil import parser as dateparser


class SnapshotProvider:
    def __init__(self, directory, git_dir, exclude_unix_patterns=None):
        self._classname = self.__class__.__name__
        author = f'{getuser()} via TFW {self._classname}'
        self.gitenv = {
            'GIT_DIR': git_dir,
            'GIT_WORK_TREE': directory,
            'GIT_AUTHOR_NAME': author,
            'GIT_AUTHOR_EMAIL': '',
            'GIT_COMMITTER_NAME': author,
            'GIT_COMMITTER_EMAIL': '',
            'GIT_PAGER': 'cat'
        }

        self._init_repo()
        self.__last_valid_branch = self._branch
        if exclude_unix_patterns:
            self.exclude = exclude_unix_patterns

    def _init_repo(self):
        self._check_environment()

        if not self._repo_is_initialized:
            self._run(('git', 'init'))

        if self._number_of_commits == 0:
            try:
                self._snapshot()
            except CalledProcessError:
                raise EnvironmentError(f'{self._classname} cannot init on empty directories!')

        self._check_head_not_detached()

    def _check_environment(self):
        if not isdir(self.gitenv['GIT_DIR']) or not isdir(self.gitenv['GIT_WORK_TREE']):
            raise EnvironmentError(f'{self._classname}: "directory" and "git_dir" must exist!')

    @property
    def _repo_is_initialized(self):
        return self._run(
            ('git', 'status'),
            check=False
        ).returncode == 0

    @property
    def _number_of_commits(self):
        return int(
            self._get_stdout((
                'git', 'rev-list',
                '--all',
                '--count'
            ))
        )

    def _snapshot(self):
        self._run((
            'git', 'add',
            '-A'
        ))
        try:
            self._get_stdout((
                'git', 'commit',
                '-m', 'Snapshot'
            ))
        except CalledProcessError as err:
            if b'nothing to commit, working tree clean' not in err.output:
                raise

    def _check_head_not_detached(self):
        if self._head_detached:
            raise EnvironmentError(f'{self._classname} cannot init from detached HEAD state!')

    @property
    def _head_detached(self):
        return self._branch == 'HEAD'

    @property
    def _branch(self):
        return self._get_stdout((
            'git', 'rev-parse',
            '--abbrev-ref', 'HEAD'
        ))

    def _get_stdout(self, *args, **kwargs):
        kwargs['stdout'] = PIPE
        kwargs['stderr'] = PIPE
        stdout_bytes = self._run(*args, **kwargs).stdout
        return stdout_bytes.decode().rstrip('\n')

    def _run(self, *args, **kwargs):
        if 'check' not in kwargs:
            kwargs['check'] = True
        if 'env' not in kwargs:
            kwargs['env'] = self.gitenv
        return run(*args, **kwargs)

    @property
    def exclude(self):
        with open(self._exclude_path, 'r') as ofile:
            return ofile.read()

    @exclude.setter
    def exclude(self, exclude_patterns):
        with open(self._exclude_path, 'w') as ifile:
            ifile.write('\n'.join(exclude_patterns))

    @property
    def _exclude_path(self):
        return joinpath(
            self.gitenv['GIT_DIR'],
            'info',
            'exclude'
        )

    def take_snapshot(self):
        if self._head_detached:
            self._checkout_new_branch_from_head()
        self._snapshot()

    def _checkout_new_branch_from_head(self):
        branch_name = str(uuid4())
        self._run((
            'git', 'branch',
            branch_name
        ))
        self._checkout(branch_name)

    def _checkout(self, what):
        self._run((
            'git', 'checkout',
            what
        ))

    def restore_snapshot(self, date):
        commit = self._get_commit_from_timestamp(date)
        branch = self._last_valid_branch
        if commit == self._latest_commit_on_branch(branch):
            commit = branch
        self._checkout(commit)

    def _get_commit_from_timestamp(self, date):
        commit = self._get_stdout((
            'git', 'rev-list',
            '--date=iso',
            '-n', '1',
            f'--before="{date.isoformat()}"',
            self._last_valid_branch
        ))
        if not commit:
            commit = self._get_oldest_parent_of_head()
        return commit

    def _get_oldest_parent_of_head(self):
        return self._get_stdout((
            'git',
            'rev-list',
            '--max-parents=0',
            'HEAD'
        ))

    @property
    def _last_valid_branch(self):
        if not self._head_detached:
            self.__last_valid_branch = self._branch
        return self.__last_valid_branch

    def _latest_commit_on_branch(self, branch):
        return self._get_stdout((
            'git', 'log',
            '-n', '1',
            '--pretty=format:%H',
            branch
        ))

    @property
    def all_timelines(self):
        return self._branches

    @property
    def _branches(self):
        git_branch_output = self._get_stdout(('git', 'branch'))
        regex_pattern = re.compile(r'(?:[^\S\n]|[*])')  # matches '*' and non-newline whitespace chars
        return re.sub(regex_pattern, '', git_branch_output).splitlines()

    @property
    def timeline(self):
        return self._last_valid_branch

    @timeline.setter
    def timeline(self, value):
        self._checkout(value)

    @property
    def snapshots(self):
        return self._pretty_log_branch()

    def _pretty_log_branch(self):
        git_log_output = self._get_stdout((
            'git', 'log',
            '--pretty=%H@%aI'
        ))

        commits = []
        for line in git_log_output.splitlines():
            commit_hash, timestamp = line.split('@')
            commits.append({
                'hash': commit_hash,
                'timestamp': dateparser.parse(timestamp)
            })

        return commits
							
								
								
									
tfw/components/terminal/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .terminal_handler import TerminalHandler
from .terminal_commands_handler import TerminalCommandsHandler
from .commands_equal import CommandsEqual
							
								
								
									
tfw/components/terminal/commands_equal.py (new file, 107 lines)
@@ -0,0 +1,107 @@
from shlex import split
from re import search

from tfw.internals.lazy import lazy_property


class CommandsEqual:
    # pylint: disable=too-many-arguments
    """
    This class is useful for comparing executed commands with
    expected commands (i.e. when triggering a state change when
    the correct command is executed).

    Note that in most cases you should test the changes
    caused by the commands instead of just checking command history
    (stuff can be done in countless ways and preparing for every
    single case is impossible). This should only be used when
    testing the changes would be very difficult, like when
    explaining stuff with cli tools and such.

    This class implicitly converts to bool, use it like
    if CommandsEqual(...): ...

    It tries detecting differing command parameter orders with similar
    semantics and provides fuzzy logic options.
    The rationale behind this is that a few false positives
    are better than only accepting a single version of a command
    (i.e. using ==).
    """
    def __init__(
            self, command_1, command_2,
            fuzzyness=1, begin_similarly=True,
            include_patterns=None, exclude_patterns=None
    ):
        """
        :param command_1: Compared command 1
        :param command_2: Compared command 2
        :param fuzzyness: float between 0 and 1.
                          the percentage of arguments required to
                          match between commands to result in True.
                          i.e. 1 means 100% - all arguments need to be
                          present in both commands, while 0.75
                          would mean 75% - in case of 4 arguments
                          1 could differ between the commands.
        :param begin_similarly: bool, the first word of the commands
                                must match
        :param include_patterns: list of regex patterns the commands
                                 must include
        :param exclude_patterns: list of regex patterns the commands
                                 must exclude
        """
        self.command_1 = split(command_1)
        self.command_2 = split(command_2)
        self.fuzzyness = fuzzyness
        self.begin_similarly = begin_similarly
        self.include_patterns = include_patterns
        self.exclude_patterns = exclude_patterns

    def __bool__(self):
        if self.begin_similarly:
            if not self.beginnings_are_equal:
                return False

        if self.include_patterns is not None:
            if not self.commands_contain_include_patterns:
                return False

        if self.exclude_patterns is not None:
            if not self.commands_contain_no_exclude_patterns:
                return False

        return self.similarity >= self.fuzzyness

    @lazy_property
    def beginnings_are_equal(self):
        return self.command_1[0] == self.command_2[0]

    @lazy_property
    def commands_contain_include_patterns(self):
        return all((
            self.contains_regex_patterns(self.command_1, self.include_patterns),
            self.contains_regex_patterns(self.command_2, self.include_patterns)
        ))

    @lazy_property
    def commands_contain_no_exclude_patterns(self):
        return all((
            not self.contains_regex_patterns(self.command_1, self.exclude_patterns),
            not self.contains_regex_patterns(self.command_2, self.exclude_patterns)
        ))

    @staticmethod
    def contains_regex_patterns(command, regex_parts):
        command = ' '.join(command)
        for pattern in regex_parts:
            if not search(pattern, command):
                return False
        return True

    @lazy_property
    def similarity(self):
        parts_1 = set(self.command_1)
        parts_2 = set(self.command_2)

        difference = parts_1 - parts_2
        deviance = len(difference) / len(max(parts_1, parts_2))
        return 1 - deviance
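Since the docstring above explains the fuzzy matching only in prose, here is a small, hypothetical usage sketch (the compared commands are arbitrary examples; it assumes the tfw package from this commit is installed):

from tfw.components.terminal import CommandsEqual

# Same arguments in a different order still compare as equal.
print(bool(CommandsEqual('tar -czf a.tgz src', 'tar src -czf a.tgz')))  # True

# Allow one differing argument out of four, but require main.c and forbid -O3.
check = CommandsEqual(
    'gcc -o prog main.c', 'gcc main.c -o prog',
    fuzzyness=0.75,
    include_patterns=[r'main\.c'],
    exclude_patterns=[r'-O3']
)
print(bool(check))  # True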
							
								
								
									
tfw/components/terminal/history_monitor.py (new file, 97 lines)
@@ -0,0 +1,97 @@
from re import findall
from re import compile as compileregex
from abc import ABC, abstractmethod

from tfw.internals.inotify import InotifyObserver


class HistoryMonitor(ABC, InotifyObserver):
    """
    Abstract class capable of monitoring and parsing a history file such as
    bash HISTFILEs. Monitoring means detecting when the file was changed and
    notifying subscribers about new content in the file.

    This is useful for monitoring CLI sessions.

    To specify a custom HistoryMonitor inherit from this class and override the
    command pattern property and optionally the sanitize_command method.
    See examples below.
    """
    def __init__(self, uplink, histfile):
        self.histfile = histfile
        self.history = []
        self._last_length = len(self.history)
        self.uplink = uplink
        super().__init__(self.histfile)

    @property
    @abstractmethod
    def domain(self):
        raise NotImplementedError()

    def on_modified(self, event):
        self._fetch_history()
        if self._last_length < len(self.history):
            for command in self.history[self._last_length:]:
                self.send_message(command)

    def _fetch_history(self):
        self._last_length = len(self.history)
        with open(self.histfile, 'r') as ifile:
            pattern = compileregex(self.command_pattern)
            data = ifile.read()
            self.history = [
                self.sanitize_command(command)
                for command in findall(pattern, data)
            ]

    @property
    @abstractmethod
    def command_pattern(self):
        raise NotImplementedError

    def sanitize_command(self, command):
        # pylint: disable=no-self-use
        return command

    def send_message(self, command):
        self.uplink.send_message({
            'key': f'history.{self.domain}',
            'value': command
        })


class BashMonitor(HistoryMonitor):
    """
    HistoryMonitor for monitoring bash CLI sessions.
    This requires the following to be set in bash
    (note that this is done automatically by TFW):
    PROMPT_COMMAND="history -a"
    shopt -s cmdhist
    shopt -s histappend
    unset HISTCONTROL
    """
    @property
    def domain(self):
        return 'bash'

    @property
    def command_pattern(self):
        return r'.+'

    def sanitize_command(self, command):
        return command.strip()


class GDBMonitor(HistoryMonitor):
    """
    HistoryMonitor to monitor GDB sessions.
    For this to work "set trace-commands on" must be set in GDB.
    """
    @property
    def domain(self):
        return 'gdb'

    @property
    def command_pattern(self):
        return r'(?<=\n)\+(.+)\n'
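The docstring above says to override command_pattern (and optionally sanitize_command) for custom monitors; a minimal hypothetical sketch follows. The subclass name, the uplink stub and the history file path are illustrative assumptions only.

from tfw.components.terminal.history_monitor import HistoryMonitor

class PythonReplMonitor(HistoryMonitor):  # hypothetical example subclass
    @property
    def domain(self):
        return 'python'  # events will arrive with key 'history.python'

    @property
    def command_pattern(self):
        return r'.+'     # treat every non-empty line as a command

class PrintUplink:       # stand-in for the real uplink exposing send_message()
    def send_message(self, message):
        print(message)

monitor = PythonReplMonitor(PrintUplink(), '/home/user/.python_history')
monitor.start()          # start() comes from the inotify observer, as used elsewhere in this commit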
							
								
								
									
tfw/components/terminal/terminado_mini_server.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import logging

from tornado.web import Application
from terminado import TermSocket, SingleTermManager

LOG = logging.getLogger(__name__)


class TerminadoMiniServer:
    def __init__(self, url, port, workdir, shellcmd):
        self.port = port
        self._term_manager = SingleTermManager(
            shell_command=shellcmd,
            term_settings={'cwd': workdir}
        )
        self.application = Application([(
            url,
            TerminadoMiniServer.ResetterTermSocket,
            {'term_manager': self._term_manager}
        )])

    @property
    def term_manager(self):
        return self._term_manager

    @property
    def pty(self):
        if self.term_manager.terminal is None:
            self.term_manager.get_terminal()
        return self.term_manager.terminal.ptyproc

    class ResetterTermSocket(TermSocket): # pylint: disable=abstract-method
        def check_origin(self, origin):
            return True

        def on_close(self):
            self.term_manager.terminal = None
            self.term_manager.get_terminal()

    def listen(self):
        self.application.listen(self.port)

    def stop(self):
        self.term_manager.shutdown()
							
								
								
									
tfw/components/terminal/terminal_commands.py (new file, 68 lines)
@@ -0,0 +1,68 @@
import logging
from abc import ABC
from re import match
from shlex import split

LOG = logging.getLogger(__name__)


class TerminalCommands(ABC):
    # pylint: disable=anomalous-backslash-in-string
    """
    A class you can use to define hooks for terminal commands. This means that you can
    have python code executed when the user enters a specific command to the terminal on
    our frontend.

    To receive events you need to subscribe TerminalCommands.callback to a HistoryMonitor
    instance.

    Inherit from this class and define methods which start with "command\_". When the user
    executes the command specified after the underscore, your method will be invoked. All
    such commands must expect the parameter \*args which will contain the arguments of the
    command.

    For example to define a method that runs when someone starts vim in the terminal
    you have to define a method like: "def command_vim(self, \*args)"

    You can also use this class to create new commands similarly.
    """
    def __init__(self, bashrc):
        self._command_method_regex = r'^command_(.+)$'
        self.command_implemetations = self._build_command_to_implementation_dict()
        if bashrc is not None:
            self._setup_bashrc_aliases(bashrc)

    def _build_command_to_implementation_dict(self):
        return {
            self._parse_command_name(fun): getattr(self, fun)
            for fun in dir(self)
            if callable(getattr(self, fun))
            and self._is_command_implementation(fun)
        }

    def _setup_bashrc_aliases(self, bashrc):
        with open(bashrc, 'a') as ofile:
            alias_template = 'type {0} &> /dev/null || alias {0}="{0} &> /dev/null"\n'
            for command in self.command_implemetations.keys():
                ofile.write(alias_template.format(command))

    def _is_command_implementation(self, method_name):
        return bool(self._match_command_regex(method_name))

    def _parse_command_name(self, method_name):
        try:
            return self._match_command_regex(method_name).groups()[0]
        except AttributeError:
            return ''

    def _match_command_regex(self, string):
        return match(self._command_method_regex, string)

    def callback(self, command):
        parts = split(command)
        command = parts[0]
        if command in self.command_implemetations.keys():
            try:
                self.command_implemetations[command](*parts[1:])
            except Exception: # pylint: disable=broad-except
                LOG.exception('Command "%s" failed:', command)
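A small hypothetical subclass illustrating the "command\_" hook convention described above; passing bashrc=None skips the alias setup, and the direct callback call stands in for the HistoryMonitor wiring used in practice.

from tfw.components.terminal.terminal_commands import TerminalCommands

class MyCommands(TerminalCommands):      # hypothetical example hooks
    def command_vim(self, *args):
        print('user started vim with arguments:', args)

commands = MyCommands(bashrc=None)       # pass a bashrc path to also write the silencing aliases
commands.callback('vim main.c')          # normally subscribed to a HistoryMonitor instance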
							
								
								
									
tfw/components/terminal/terminal_commands_handler.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .terminal_commands import TerminalCommands


class TerminalCommandsHandler(TerminalCommands):
    keys = ['history.bash']

    def handle_event(self, message, _):
        command = message['value']
        self.callback(command)
							
								
								
									
tfw/components/terminal/terminal_handler.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import logging

from .history_monitor import BashMonitor
from .terminado_mini_server import TerminadoMiniServer


LOG = logging.getLogger(__name__)


class TerminalHandler:
    keys = ['shell']
    """
    Event handler responsible for managing terminal sessions for frontend xterm
    sessions to connect to. You need to instantiate this in order for frontend
    terminals to work.

    This EventHandler accepts messages that have a data['command'] key specifying
    a command to be executed.
    The API of each command is documented in their respective handler.
    """
    def __init__(self, *, port, user, workind_directory, histfile):
        """
        :param key: key this EventHandler listens to
        :param monitor: tfw.components.HistoryMonitor instance to read command history from
        """
        self.server_connector = None
        self._histfile = histfile
        self._historymonitor = None
        bash_as_user_cmd = ['sudo', '-u', user, 'bash']

        self.terminado_server = TerminadoMiniServer(
            '/terminal',
            port,
            workind_directory,
            bash_as_user_cmd
        )

        self.commands = {
            'write': self.write,
            'read': self.read
        }

        self.terminado_server.listen()

    def start(self):
        self._historymonitor = BashMonitor(self.server_connector, self._histfile)
        self._historymonitor.start()

    @property
    def historymonitor(self):
        return self._historymonitor

    def handle_event(self, message, _):
        try:
            data = message['data']
            message['data'] = self.commands[data['command']](data)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def write(self, data):
        """
        Writes a string to the terminal session (on the pty level).
        Useful for pre-typing and executing commands for the user.

        :param data: TFW message data containing 'value'
                     (command to be written to the pty)
        """
        self.terminado_server.pty.write(data['value'])
        return data

    def read(self, data):
        """
        Reads the history of commands executed.

        :param data: TFW message data containing 'count'
                     (the number of history elements to return)
        :return dict: message with list of commands in data['history']
        """
        data['count'] = int(data.get('count', 1))
        if self.historymonitor:
            data['history'] = self.historymonitor.history[-data['count']:]
        return data

    def cleanup(self):
        self.terminado_server.stop()
        self.historymonitor.stop()
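A hypothetical sketch of the message formats the write and read handlers above accept; the constructor values are placeholders, and constructing the handler starts a Tornado listener plus a shell via sudo, so this is illustrative rather than a drop-in snippet.

from tfw.components.terminal import TerminalHandler

terminal = TerminalHandler(
    port=8888,                          # placeholder port for TerminadoMiniServer
    user='user',                        # shell is started as this user via sudo
    workind_directory='/home/user',     # parameter name as defined in this commit
    histfile='/home/user/.bash_history'
)

# Pre-type and run a command in the frontend terminal:
terminal.handle_event({'key': 'shell',
                       'data': {'command': 'write', 'value': 'ls -la\n'}}, None)

# Fetch the last 5 commands the user executed:
terminal.handle_event({'key': 'shell',
                       'data': {'command': 'read', 'count': 5}}, None)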
							
								
								
									
tfw/config/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .envvars import TFWENV, TAOENV
							
								
								
									
tfw/config/envvars.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .lazy_environment import LazyEnvironment


TFWENV = LazyEnvironment('TFW_', 'tfwenvtuple').environment
TAOENV = LazyEnvironment('AVATAO_', 'taoenvtuple').environment
							
								
								
									
tfw/config/lazy_environment.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from collections import namedtuple
from os import environ

from tfw.internals.lazy import lazy_property


class LazyEnvironment:
    def __init__(self, prefix, tuple_name):
        self._prefix = prefix
        self._tuple_name = tuple_name

    @lazy_property
    def environment(self):
        return self.prefixed_envvars_to_namedtuple()

    def prefixed_envvars_to_namedtuple(self):
        envvars = {
            envvar.replace(self._prefix, '', 1): environ.get(envvar)
            for envvar in environ.keys()
            if envvar.startswith(self._prefix)
        }
        return namedtuple(self._tuple_name, envvars)(**envvars)
							
								
								
									
tfw/event_handlers.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# pylint: disable=unused-import
from tfw.internals.event_handling import EventHandler, FSMAwareEventHandler
							
								
								
									
tfw/fsm/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .fsm_base import FSMBase
from .linear_fsm import LinearFSM
from .yaml_fsm import YamlFSM
							
								
								
									
										84
									
								
								tfw/fsm/fsm_base.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tfw/fsm/fsm_base.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
			
		||||
import logging
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
from transitions import Machine, MachineError
 | 
			
		||||
 | 
			
		||||
from tfw.internals.callback_mixin import CallbackMixin
 | 
			
		||||
 | 
			
		||||
LOG = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FSMBase(Machine, CallbackMixin):
 | 
			
		||||
    """
 | 
			
		||||
    A general FSM base class you can inherit from to track user progress.
 | 
			
		||||
    See linear_fsm.py for an example use-case.
 | 
			
		||||
    TFW uses the transitions library for state machines, please refer to their
 | 
			
		||||
    documentation for more information on creating your own machines:
 | 
			
		||||
    https://github.com/pytransitions/transitions
 | 
			
		||||
    """
 | 
			
		||||
    states, transitions = [], []
 | 
			
		||||
 | 
			
		||||
    def __init__(self, initial=None, accepted_states=None):
 | 
			
		||||
        """
 | 
			
		||||
        :param initial: which state to begin with, defaults to the last one
 | 
			
		||||
        :param accepted_states: list of states in which the challenge should be
 | 
			
		||||
                                considered successfully completed
 | 
			
		||||
        """
 | 
			
		||||
        self.accepted_states = accepted_states or [self.states[-1].name]
 | 
			
		||||
        self.trigger_predicates = defaultdict(list)
 | 
			
		||||
        self.event_log = []
 | 
			
		||||
 | 
			
		||||
        Machine.__init__(
 | 
			
		||||
            self,
 | 
			
		||||
            states=self.states,
 | 
			
		||||
            transitions=self.transitions,
 | 
			
		||||
            initial=initial or self.states[0],
 | 
			
		||||
            send_event=True,
 | 
			
		||||
            ignore_invalid_triggers=True,
 | 
			
		||||
            after_state_change='execute_callbacks'
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def execute_callbacks(self, event_data):
 | 
			
		||||
        self._execute_callbacks(event_data.kwargs)
 | 
			
		||||
 | 
			
		||||
    def is_solved(self):
 | 
			
		||||
        return self.state in self.accepted_states  # pylint: disable=no-member
 | 
			
		||||
 | 
			
		||||
    def subscribe_predicate(self, trigger, *predicates):
 | 
			
		||||
        self.trigger_predicates[trigger].extend(predicates)
 | 
			
		||||
 | 
			
		||||
    def unsubscribe_predicate(self, trigger, *predicates):
 | 
			
		||||
        self.trigger_predicates[trigger] = [
 | 
			
		||||
            predicate
 | 
			
		||||
            for predicate in self.trigger_predicates[trigger]
 | 
			
		||||
            not in predicates
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def step(self, trigger):
 | 
			
		||||
        predicate_results = (
 | 
			
		||||
            predicate()
 | 
			
		||||
            for predicate in self.trigger_predicates[trigger]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if all(predicate_results):
 | 
			
		||||
            try:
 | 
			
		||||
                from_state = self.state
 | 
			
		||||
                self.trigger(trigger)
 | 
			
		||||
                self.update_event_log(from_state, trigger)
 | 
			
		||||
                return True
 | 
			
		||||
            except (AttributeError, MachineError):
 | 
			
		||||
                LOG.debug('FSM failed to execute nonexistent trigger: "%s"', trigger)
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    def update_event_log(self, from_state, trigger):
 | 
			
		||||
        self.event_log.append({
 | 
			
		||||
            'from_state': from_state,
 | 
			
		||||
            'to_state': self.state,
 | 
			
		||||
            'trigger': trigger,
 | 
			
		||||
            'timestamp': datetime.utcnow()
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def in_accepted_state(self):
 | 
			
		||||
        return self.state in self.accepted_states
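
# Illustrative usage sketch (not part of the diff above): a minimal FSMBase
# subclass with two states. The class, state and trigger names are invented
# for the example; FSMBase itself only requires the class-level 'states' and
# 'transitions' attributes shown here.
from transitions import State

from tfw.fsm import FSMBase


class TwoStepFSM(FSMBase):
    states = [State(name='start'), State(name='done')]
    transitions = [{'trigger': 'finish', 'source': 'start', 'dest': 'done'}]


fsm = TwoStepFSM()        # initial state defaults to the first one, 'start'
fsm.step('finish')        # evaluates subscribed predicates, then fires the trigger
assert fsm.is_solved()    # 'done' is the accepted state by default (the last one)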

tfw/fsm/linear_fsm.py (32 lines, Normal file)
@@ -0,0 +1,32 @@
from transitions import State

from .fsm_base import FSMBase


class LinearFSM(FSMBase):
    # pylint: disable=anomalous-backslash-in-string
    """
    This is a state machine for challenges with linear progression, consisting of
    a number of steps specified in the constructor. It automatically sets up 2
    actions (triggers) between states as such:
    (0) --  step_1     --> (1) --  step_2     --> (2) --  step_3     --> (3) ...
    (0) --  step_next  --> (1) --  step_next  --> (2) --  step_next  --> (3) ...
    """
    def __init__(self, number_of_steps):
        """
        :param number_of_steps: how many states this FSM should have
        """
        self.states = [State(name=str(index)) for index in range(number_of_steps)]
        self.transitions = []
        for state in self.states[:-1]:
            self.transitions.append({
                'trigger': f'step_{int(state.name)+1}',
                'source': state.name,
                'dest': str(int(state.name)+1)
            })
            self.transitions.append({
                'trigger': 'step_next',
                'source': state.name,
                'dest': str(int(state.name)+1)
            })
        super(LinearFSM, self).__init__()
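
# Illustrative usage sketch (not part of the diff above): LinearFSM names its
# states '0'..'N-1' and accepts either the specific step_<n> trigger or the
# generic step_next trigger.
from tfw.fsm import LinearFSM

fsm = LinearFSM(3)       # states '0', '1' and '2'
fsm.step('step_1')       # '0' -> '1'
fsm.step('step_next')    # '1' -> '2'
assert fsm.in_accepted_state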

tfw/fsm/yaml_fsm.py (105 lines, Normal file)
@@ -0,0 +1,105 @@
from subprocess import Popen, run
from functools import partial, singledispatch
from contextlib import suppress

import yaml
import jinja2
from transitions import State

from .fsm_base import FSMBase


class YamlFSM(FSMBase):
    """
    This is a state machine capable of building itself from a YAML config file.
    """
    def __init__(self, config_file, jinja2_variables=None):
        """
        :param config_file: path of the YAML file
        :param jinja2_variables: dict containing jinja2 variables
                                 or str with filename of YAML file to
                                 parse and use as dict.
                                 jinja2 support is disabled if this is None
        """
        self.config = ConfigParser(config_file, jinja2_variables).config
        self.setup_states()
        super().__init__()  # FSMBase.__init__() requires states
        self.setup_transitions()

    def setup_states(self):
        self.for_config_states_and_transitions_do(self.wrap_callbacks_with_subprocess_call)
        self.states = [State(**state) for state in self.config['states']]

    def setup_transitions(self):
        self.for_config_states_and_transitions_do(self.subscribe_and_remove_predicates)
        for transition in self.config['transitions']:
            self.add_transition(**transition)

    def for_config_states_and_transitions_do(self, what):
        for array in ('states', 'transitions'):
            for json_obj in self.config[array]:
                what(json_obj)

    @staticmethod
    def wrap_callbacks_with_subprocess_call(json_obj):
        topatch = ('on_enter', 'on_exit', 'prepare', 'before', 'after')
        for key in json_obj:
            if key in topatch:
                json_obj[key] = partial(run_command_async, json_obj[key])

    def subscribe_and_remove_predicates(self, json_obj):
        if 'predicates' in json_obj:
            for predicate in json_obj['predicates']:
                self.subscribe_predicate(
                    json_obj['trigger'],
                    partial(
                        command_statuscode_is_zero,
                        predicate
                    )
                )

        with suppress(KeyError):
            json_obj.pop('predicates')


def run_command_async(command, _):
    Popen(command, shell=True)


def command_statuscode_is_zero(command):
    return run(command, shell=True).returncode == 0


class ConfigParser:
    def __init__(self, config_file, jinja2_variables):
        self.read_variables = singledispatch(self._read_variables)
        self.read_variables.register(dict, self._read_variables_dict)
        self.read_variables.register(str, self._read_variables_str)

        self.config = self.parse_config(config_file, jinja2_variables)

    def parse_config(self, config_file, jinja2_variables):
        config_string = self.read_file(config_file)
        if jinja2_variables is not None:
            variables = self.read_variables(jinja2_variables)
            template = jinja2.Environment(loader=jinja2.BaseLoader).from_string(config_string)
            config_string = template.render(**variables)
        return yaml.safe_load(config_string)

    @staticmethod
    def read_file(filename):
        with open(filename, 'r') as ifile:
            return ifile.read()

    @staticmethod
    def _read_variables(variables):
        raise TypeError(f'Invalid variables type {type(variables)}')

    @staticmethod
    def _read_variables_str(variables):
        with open(variables, 'r') as ifile:
            return yaml.safe_load(ifile)

    @staticmethod
    def _read_variables_dict(variables):
        return variables
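
# Illustrative usage sketch (not part of the diff above). The YAML schema is
# implied by the code: 'states' entries become transitions.State kwargs,
# 'transitions' entries are passed to add_transition, 'predicates' are shell
# commands that must exit with 0 for the trigger to fire, and callback keys
# such as on_enter are run as shell commands. The file contents and the shell
# commands below are made up for the example.
from tempfile import NamedTemporaryFile

from tfw.fsm import YamlFSM

CONFIG = """
states:
    - name: '0'
    - name: '1'
      on_enter: 'touch /tmp/reached_state_1'
transitions:
    - trigger: step_1
      source: '0'
      dest: '1'
      predicates:
          - 'test -e /etc/passwd'
"""

with NamedTemporaryFile('w', suffix='.yml', delete=False) as ofile:
    ofile.write(CONFIG)

fsm = YamlFSM(ofile.name)
fsm.step('step_1')    # runs the predicate, steps to '1', then runs on_enter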

tfw/internals/__init__.py (0 lines, Normal file)

tfw/internals/callback_mixin.py (35 lines, Normal file)
@@ -0,0 +1,35 @@
from functools import partial

from .lazy import lazy_property


class CallbackMixin:
    @lazy_property
    def _callbacks(self):
        # pylint: disable=no-self-use
        return []

    def subscribe_callback(self, callback, *args, **kwargs):
        """
        Subscribe a callable to invoke once an event is triggered.
        :param callback: callable to be executed on events
        :param args: arguments passed to callable
        :param kwargs: kwargs passed to callable
        """
        fun = partial(callback, *args, **kwargs)
        self._callbacks.append(fun)

    def subscribe_callbacks(self, *callbacks):
        """
        Subscribe a list of callbacks to invoke once an event is triggered.
        :param callbacks: callbacks to be subscribed
        """
        for callback in callbacks:
            self.subscribe_callback(callback)

    def unsubscribe_callback(self, callback):
        self._callbacks.remove(callback)

    def _execute_callbacks(self, *args, **kwargs):
        for callback in self._callbacks:
            callback(*args, **kwargs)
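
# Illustrative usage sketch (not part of the diff above): any class mixing in
# CallbackMixin can collect callables and fire them all at once. The Notifier
# class is invented for the example.
from tfw.internals.callback_mixin import CallbackMixin


class Notifier(CallbackMixin):
    def fire(self, payload):
        self._execute_callbacks(payload)


notifier = Notifier()
notifier.subscribe_callback(print, 'event:')    # stored as partial(print, 'event:')
notifier.fire('hello')                          # prints "event: hello"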

tfw/internals/crypto.py (107 lines, Normal file)
@@ -0,0 +1,107 @@
from functools import wraps
from base64 import b64encode, b64decode
from copy import deepcopy
from hashlib import md5
from os import urandom, chmod
from os.path import exists
from stat import S_IRUSR, S_IWUSR, S_IXUSR

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
from cryptography.exceptions import InvalidSignature

from tfw.internals.networking import message_bytes
from tfw.internals.lazy import lazy_property
from tfw.config import TFWENV


def message_checksum(message):
    return md5(message_bytes(message)).hexdigest()


def sign_message(key, message):
    message.pop('scope', None)
    message.pop('signature', None)
    signature = message_signature(key, message)
    message['signature'] = b64encode(signature).decode()


def message_signature(key, message):
    return HMAC(key, message_bytes(message)).signature


def verify_message(key, message):
    message.pop('scope', None)
    message = deepcopy(message)
    try:
        signature_b64 = message.pop('signature')
        signature = b64decode(signature_b64)
        actual_signature = message_signature(key, message)
        return signature == actual_signature
    except KeyError:
        return False


class KeyManager:
    def __init__(self):
        self.keyfile = TFWENV.AUTH_KEY
        if not exists(self.keyfile):
            self._init_auth_key()

    @lazy_property
    def auth_key(self):
        with open(self.keyfile, 'rb') as ifile:
            return ifile.read()

    def _init_auth_key(self):
        key = self.generate_key()
        with open(self.keyfile, 'wb') as ofile:
            ofile.write(key)
        self._chmod_700_keyfile()
        return key

    @staticmethod
    def generate_key():
        return urandom(32)

    def _chmod_700_keyfile(self):
        chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)


class HMAC:
    def __init__(self, key, message):
        self.key = key
        self.message = message
        self._hmac = _HMAC(
            key=key,
            algorithm=SHA256(),
            backend=default_backend()
        )

    def _reload_if_finalized(f):
        # pylint: disable=no-self-argument,not-callable
        @wraps(f)
        def wrapped(instance, *args, **kwargs):
            if getattr(instance, '_finalized', False):
                instance.__init__(instance.key, instance.message)
            ret_val = f(instance, *args, **kwargs)
            setattr(instance, '_finalized', True)
            return ret_val
        return wrapped

    @property
    @_reload_if_finalized
    def signature(self):
        self._hmac.update(self.message)
        signature = self._hmac.finalize()
        return signature

    @_reload_if_finalized
    def verify(self, signature):
        self._hmac.update(self.message)
        try:
            self._hmac.verify(signature)
            return True
        except InvalidSignature:
            return False
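
# Illustrative usage sketch (not part of the diff above). KeyManager reads (or
# creates) the key file at TFWENV.AUTH_KEY, so this assumes that path exists
# and is writable in the current environment; the message dict is made up.
from tfw.internals.crypto import KeyManager, sign_message, verify_message

key = KeyManager().auth_key
message = {'key': 'fsm_update', 'current_state': '0'}

sign_message(key, message)              # adds a base64 'signature' field in place
assert verify_message(key, message)     # holds while the payload is untouched

message['current_state'] = '1'
assert not verify_message(key, message)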

tfw/internals/event_handling/__init__.py (3 lines, Normal file)
@@ -0,0 +1,3 @@
from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler
from .fsm_aware_event_handler import FSMAwareEventHandler

tfw/internals/event_handling/event_handler.py (27 lines, Normal file)
@@ -0,0 +1,27 @@
class EventHandler:
    _instances = set()

    def __init__(self, server_connector):
        type(self)._instances.add(self)
        self.server_connector = server_connector

    def start(self):
        self.server_connector.register_callback(self._event_callback)

    def _event_callback(self, message):
        self.handle_event(message, self.server_connector)

    def handle_event(self, message, server_connector):
        raise NotImplementedError()

    @classmethod
    def stop_all_instances(cls):
        for instance in cls._instances:
            instance.stop()

    def stop(self):
        self.server_connector.close()
        self.cleanup()

    def cleanup(self):
        pass

tfw/internals/event_handling/event_handler_factory_base.py (68 lines, Normal file)
@@ -0,0 +1,68 @@
from contextlib import suppress

from .event_handler import EventHandler


class EventHandlerFactoryBase:
    def build(self, handler_stub, *, keys=None, event_handler_type=EventHandler):
        builder = EventHandlerBuilder(handler_stub, keys, event_handler_type)
        server_connector = self._build_server_connector()
        event_handler = builder.build(server_connector)
        handler_stub.server_connector = server_connector
        with suppress(AttributeError):
            handler_stub.start()
        event_handler.start()
        return event_handler

    def _build_server_connector(self):
        raise NotImplementedError()


class EventHandlerBuilder:
    def __init__(self, event_handler, supplied_keys, event_handler_type):
        self._analyzer = HandlerStubAnalyzer(event_handler, supplied_keys)
        self._event_handler_type = event_handler_type

    def build(self, server_connector):
        event_handler = self._event_handler_type(server_connector)
        server_connector.subscribe(*self._try_get_keys(event_handler))
        event_handler.handle_event = self._analyzer.handle_event
        with suppress(AttributeError):
            event_handler.cleanup = self._analyzer.cleanup
        return event_handler

    def _try_get_keys(self, event_handler):
        try:
            return self._analyzer.keys
        except ValueError:
            with suppress(AttributeError):
                return event_handler.keys
            raise


class HandlerStubAnalyzer:
    def __init__(self, event_handler, supplied_keys):
        self._event_handler = event_handler
        self._supplied_keys = supplied_keys

    @property
    def keys(self):
        if self._supplied_keys is None:
            try:
                return self._event_handler.keys
            except AttributeError:
                raise ValueError('No keys supplied!')
        return self._supplied_keys

    @property
    def handle_event(self):
        try:
            return self._event_handler.handle_event
        except AttributeError:
            if callable(self._event_handler):
                return self._event_handler
            raise ValueError('Object must implement handle_event or be a callable!')

    @property
    def cleanup(self):
        return self._event_handler.cleanup
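
# Illustrative sketch (not part of the diff above) of what a "handler stub"
# passed to EventHandlerFactoryBase.build() can look like: it only needs
# handle_event (or to be callable), plus optional keys, start and cleanup.
# EchoHandler is invented, and SomeEventHandlerFactory stands for a concrete
# factory that implements _build_server_connector.
class EchoHandler:
    keys = ['message']

    def __init__(self):
        self.server_connector = None    # injected by the factory in build()

    def handle_event(self, message, server_connector):
        server_connector.send_message(message)


# event_handler = SomeEventHandlerFactory().build(EchoHandler())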

tfw/internals/event_handling/fsm_aware.py (37 lines, Normal file)
@@ -0,0 +1,37 @@
import logging

from tfw.internals.crypto import KeyManager, verify_message

LOG = logging.getLogger(__name__)


class FSMAware:
    keys = ['fsm_update']
    """
    Base class for stuff that has to be aware of the framework FSM.
    This is done by processing 'fsm_update' messages.
    """
    def __init__(self):
        self.fsm_state = None
        self.fsm_in_accepted_state = False
        self.fsm_event_log = []
        self._auth_key = KeyManager().auth_key

    def process_message(self, message):
        if message['key'] == 'fsm_update':
            if verify_message(self._auth_key, message):
                self._handle_fsm_update(message)

    def _handle_fsm_update(self, message):
        try:
            new_state = message['current_state']
            if self.fsm_state != new_state:
                self.handle_fsm_step(message)
            self.fsm_state = new_state
            self.fsm_in_accepted_state = message['in_accepted_state']
            self.fsm_event_log.append(message)
        except KeyError:
            LOG.error('Invalid fsm_update message received!')

    def handle_fsm_step(self, message):
        pass
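
# Illustrative sketch (not part of the diff above): a class mixing in FSMAware
# feeds incoming messages through process_message and only reacts to signed
# 'fsm_update' messages. StateLogger is invented; instantiating it requires a
# readable auth key at TFWENV.AUTH_KEY because of KeyManager.
from tfw.internals.event_handling.fsm_aware import FSMAware


class StateLogger(FSMAware):
    def handle_fsm_step(self, message):
        print('FSM stepped to state', message['current_state'])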

tfw/internals/event_handling/fsm_aware_event_handler.py (19 lines, Normal file)
@@ -0,0 +1,19 @@
from .event_handler import EventHandler
from .fsm_aware import FSMAware


class FSMAwareEventHandler(EventHandler, FSMAware):
    # pylint: disable=abstract-method
    """
    Abstract base class for EventHandlers which automatically
    keep track of the state of the TFW FSM.
    """
    def __init__(self, server_connector):
        EventHandler.__init__(self, server_connector)
        FSMAware.__init__(self)

    def _event_callback(self, message):
        self.process_message(message)

    def handle_fsm_step(self, message):
        self.handle_event(message, self.server_connector)

tfw/internals/event_handling/test_event_handler.py (190 lines, Normal file)
@@ -0,0 +1,190 @@
# pylint: disable=redefined-outer-name,attribute-defined-outside-init
from secrets import token_urlsafe
from random import randint

import pytest

from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler


class MockEventHandlerFactory(EventHandlerFactoryBase):
    def _build_server_connector(self):
        return MockServerConnector()


class MockServerConnector:
    def __init__(self):
        self.keys = []
        self._on_message = None

    def simulate_message(self, message):
        self._on_message(message)

    def register_callback(self, callback):
        self._on_message = callback

    def subscribe(self, *keys):
        self.keys.extend(keys)

    def unsubscribe(self, *keys):
        for key in keys:
            self.keys.remove(key)

    def send_message(self, message, scope=None):
        pass

    def close(self):
        pass


class MockEventHandlerStub:
    def __init__(self):
        self.server_connector = None
        self.last_message = None
        self.cleaned_up = False
        self.started = False

    def start(self):
        self.started = True

    def cleanup(self):
        self.cleaned_up = True


class MockEventHandler(MockEventHandlerStub):
    # pylint: disable=unused-argument
    def handle_event(self, message, server_connector):
        self.last_message = message


class MockCallable(MockEventHandlerStub):
    def __call__(self, message, server_connector):
        self.last_message = message


@pytest.fixture
def test_msg():
    yield token_urlsafe(randint(16, 64))


@pytest.fixture
def test_keys():
    yield [
        token_urlsafe(randint(2, 8))
        for _ in range(randint(16, 32))
    ]


def test_build_from_object(test_keys, test_msg):
    mock_eh = MockEventHandlerStub()
    def handle_event(message, server_connector):
        raise RuntimeError(message, server_connector.keys)
    mock_eh.handle_event = handle_event

    assert not mock_eh.started
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)

    assert mock_eh.started
    assert mock_eh.server_connector is eh.server_connector
    with pytest.raises(RuntimeError) as err:
        eh.server_connector.simulate_message(test_msg)
        msg, keys = err.args
        assert msg == test_msg
        assert keys == test_keys
    assert not mock_eh.cleaned_up
    eh.stop()
    assert mock_eh.cleaned_up


def test_build_from_object_with_keys(test_keys, test_msg):
    mock_eh = MockEventHandler()
    mock_eh.keys = test_keys

    assert not mock_eh.started
    eh = MockEventHandlerFactory().build(mock_eh)

    assert mock_eh.server_connector.keys == test_keys
    assert eh.server_connector is mock_eh.server_connector
    assert mock_eh.started
    assert not mock_eh.last_message
    eh.server_connector.simulate_message(test_msg)
    assert mock_eh.last_message == test_msg
    assert not mock_eh.cleaned_up
    EventHandler.stop_all_instances()
    assert mock_eh.cleaned_up


def test_build_from_simple_object(test_keys, test_msg):
    class SimpleMockEventHandler:
        # pylint: disable=no-self-use
        def handle_event(self, message, server_connector):
            raise RuntimeError(message, server_connector)

    mock_eh = SimpleMockEventHandler()
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)

    with pytest.raises(RuntimeError) as err:
        eh.server_connector.simulate_message(test_msg)
        msg, keys = err.args
        assert msg == test_msg
        assert keys == test_keys


def test_build_from_callable(test_keys, test_msg):
    mock_eh = MockCallable()

    assert not mock_eh.started
    eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)

    assert mock_eh.started
    assert mock_eh.server_connector is eh.server_connector
    assert eh.server_connector.keys == test_keys
    assert not mock_eh.last_message
    eh.server_connector.simulate_message(test_msg)
    assert mock_eh.last_message == test_msg
    assert not mock_eh.cleaned_up
    eh.stop()
    assert mock_eh.cleaned_up


def test_build_from_function(test_keys, test_msg):
    def some_function(message, server_connector):
        raise RuntimeError(message, server_connector.keys)
    eh = MockEventHandlerFactory().build(some_function, keys=test_keys)

    assert eh.server_connector.keys == test_keys
    with pytest.raises(RuntimeError) as err:
        eh.server_connector.simulate_message(test_msg)
        msg, keys = err.args
        assert msg == test_msg
        assert keys == test_keys


def test_build_from_lambda(test_keys, test_msg):
    def assert_messages_equal(msg):
        assert msg == test_msg
    fun = lambda msg, sc: assert_messages_equal(msg)
    eh = MockEventHandlerFactory().build(fun, keys=test_keys)
    eh.server_connector.simulate_message(test_msg)


def test_build_raises_if_no_key(test_keys):
    eh = MockEventHandler()
    with pytest.raises(ValueError):
        MockEventHandlerFactory().build(eh)

    def handle_event(*_):
        pass
    with pytest.raises(ValueError):
        MockEventHandlerFactory().build(handle_event)

    with pytest.raises(ValueError):
        MockEventHandlerFactory().build(lambda msg, sc: None)

    WithKeysEventHandler = EventHandler
    WithKeysEventHandler.keys = test_keys
    MockEventHandlerFactory().build(eh, event_handler_type=WithKeysEventHandler)

    eh.keys = test_keys
    MockEventHandlerFactory().build(eh)

tfw/internals/inotify/__init__.py (6 lines, Normal file)
@@ -0,0 +1,6 @@
from .inotify import InotifyObserver
from .inotify import (
    InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
    InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
    InotifyDirMovedEvent, InotifyDirDeletedEvent
)

tfw/internals/inotify/inotify.py (189 lines, Normal file)
@@ -0,0 +1,189 @@
# pylint: disable=too-few-public-methods

from typing import Iterable
from time import time
from os.path import abspath, dirname, isdir

from watchdog.observers import Observer
from watchdog.events import FileSystemMovedEvent, PatternMatchingEventHandler
from watchdog.events import (
    FileCreatedEvent, FileModifiedEvent, FileMovedEvent, FileDeletedEvent,
    DirCreatedEvent, DirModifiedEvent, DirMovedEvent, DirDeletedEvent
)


class InotifyEvent:
    def __init__(self, src_path):
        self.date = time()
        self.src_path = src_path

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.src_path})'


class InotifyMovedEvent(InotifyEvent):
    def __init__(self, src_path, dest_path):
        self.dest_path = dest_path
        super().__init__(src_path)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.src_path}, {self.dest_path})'


class InotifyFileCreatedEvent(InotifyEvent):
    pass


class InotifyFileModifiedEvent(InotifyEvent):
    pass


class InotifyFileMovedEvent(InotifyMovedEvent):
    pass


class InotifyFileDeletedEvent(InotifyEvent):
    pass


class InotifyDirCreatedEvent(InotifyEvent):
    pass


class InotifyDirModifiedEvent(InotifyEvent):
    pass


class InotifyDirMovedEvent(InotifyMovedEvent):
    pass


class InotifyDirDeletedEvent(InotifyEvent):
    pass


class InotifyObserver:
    def __init__(self, path, patterns=None, exclude=None, recursive=False):
        self._files = []
        self._paths = path
        self._patterns = patterns or []
        self._exclude = exclude
        self._recursive = recursive
        self._observer = Observer()
        self._reset()

    def _reset(self):
        if isinstance(self._paths, str):
            self._paths = [self._paths]
        if isinstance(self._paths, Iterable):
            self._extract_files_from_paths()
        else:
            raise ValueError('Expected one or more string paths.')

        patterns = self._files+self.patterns
        handler = PatternMatchingEventHandler(patterns if patterns else None, self.exclude)
        handler.on_any_event = self._dispatch_event
        self._observer.unschedule_all()
        for path in self.paths:
            self._observer.schedule(handler, path, self._recursive)

    def _extract_files_from_paths(self):
        files, paths = [], []
        for path in self._paths:
            path = abspath(path)
            if isdir(path):
                paths.append(path)
            else:
                paths.append(dirname(path))
                files.append(path)
        self._files, self._paths = files, paths

    @property
    def paths(self):
        return self._paths

    @paths.setter
    def paths(self, paths):
        self._paths = paths
        self._reset()

    @property
    def patterns(self):
        return self._patterns

    @patterns.setter
    def patterns(self, patterns):
        self._patterns = patterns or []
        self._reset()

    @property
    def exclude(self):
        return self._exclude

    @exclude.setter
    def exclude(self, exclude):
        self._exclude = exclude
        self._reset()

    def start(self):
        self._observer.start()

    def stop(self):
        self._observer.stop()
        self._observer.join()

    def _dispatch_event(self, event):
        event_to_action = {
            InotifyFileCreatedEvent  : self.on_created,
            InotifyFileModifiedEvent : self.on_modified,
            InotifyFileMovedEvent    : self.on_moved,
            InotifyFileDeletedEvent  : self.on_deleted,
            InotifyDirCreatedEvent   : self.on_created,
            InotifyDirModifiedEvent  : self.on_modified,
            InotifyDirMovedEvent     : self.on_moved,
            InotifyDirDeletedEvent   : self.on_deleted
        }

        event = self._transform_event(event)
        self.on_any_event(event)
        event_to_action[type(event)](event)

    @staticmethod
    def _transform_event(event):
        watchdog_to_inotify = {
            FileCreatedEvent  : InotifyFileCreatedEvent,
            FileModifiedEvent : InotifyFileModifiedEvent,
            FileMovedEvent    : InotifyFileMovedEvent,
            FileDeletedEvent  : InotifyFileDeletedEvent,
            DirCreatedEvent   : InotifyDirCreatedEvent,
            DirModifiedEvent  : InotifyDirModifiedEvent,
            DirMovedEvent     : InotifyDirMovedEvent,
            DirDeletedEvent   : InotifyDirDeletedEvent
        }

        try:
            cls = watchdog_to_inotify[type(event)]
        except KeyError:
            raise NameError('Watchdog API returned an unknown event.')

        if isinstance(event, FileSystemMovedEvent):
            return cls(event.src_path, event.dest_path)
        return cls(event.src_path)

    def on_any_event(self, event):
        pass

    def on_created(self, event):
        pass

    def on_modified(self, event):
        pass

    def on_moved(self, event):
        pass

    def on_deleted(self, event):
        pass
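
# Illustrative usage sketch (not part of the diff above): handlers are
# overridden rather than registered. The watched path and pattern are made up.
from tfw.internals.inotify import InotifyObserver


class TxtWatcher(InotifyObserver):
    def on_modified(self, event):
        print('modified:', event.src_path)


watcher = TxtWatcher('/tmp', patterns=['*.txt'], recursive=True)
watcher.start()    # watchdog observer thread runs until watcher.stop()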

tfw/internals/inotify/test_inotify.py (179 lines, Normal file)
@@ -0,0 +1,179 @@
# pylint: disable=redefined-outer-name

from queue import Empty, Queue
from secrets import token_urlsafe
from pathlib import Path
from shutil import rmtree
from os.path import join
from os import mkdir, remove, rename
from tempfile import TemporaryDirectory
from contextlib import suppress

import watchdog
import pytest

from .inotify import InotifyObserver
from .inotify import (
    InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
    InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
    InotifyDirMovedEvent, InotifyDirDeletedEvent
)

with suppress(AttributeError):
    watchdog.observers.inotify_buffer.InotifyBuffer.delay = 0


class InotifyContext:
    def __init__(self, workdir, subdir, subfile, observer):
        self.missing_events = 0
        self.workdir = workdir
        self.subdir = subdir
        self.subfile = subfile
        self.observer = observer

        self.event_to_queue = {
            InotifyFileCreatedEvent  : self.observer.create_queue,
            InotifyFileModifiedEvent : self.observer.modify_queue,
            InotifyFileMovedEvent    : self.observer.move_queue,
            InotifyFileDeletedEvent  : self.observer.delete_queue,
            InotifyDirCreatedEvent   : self.observer.create_queue,
            InotifyDirModifiedEvent  : self.observer.modify_queue,
            InotifyDirMovedEvent     : self.observer.move_queue,
            InotifyDirDeletedEvent   : self.observer.delete_queue
        }

    def create_random_file(self, dirname, extension):
        filename = self.join(f'{dirname}/{generate_name()}{extension}')
        Path(filename).touch()
        return filename

    def create_random_folder(self, basepath):
        dirname = self.join(f'{basepath}/{generate_name()}')
        mkdir(dirname)
        return dirname

    def join(self, path):
        return join(self.workdir, path)

    def check_event(self, event_type, path):
        self.missing_events += 1
        event = self.event_to_queue[event_type].get(timeout=0.1)
        assert isinstance(event, event_type)
        assert event.src_path == path
        return event

    def check_empty(self, event_type):
        with pytest.raises(Empty):
            self.event_to_queue[event_type].get(timeout=0.1)

    def check_any(self):
        attrs = self.observer.__dict__.values()
        total = sum([q.qsize() for q in attrs if isinstance(q, Queue)])
        return total+self.missing_events == len(self.observer.any_list)


class InotifyTestObserver(InotifyObserver):
    def __init__(self, paths, patterns=None, exclude=None, recursive=False):
        self.any_list = []
        self.create_queue, self.modify_queue, self.move_queue, self.delete_queue = [Queue() for _ in range(4)]
        super().__init__(paths, patterns, exclude, recursive)

    def on_any_event(self, event):
        self.any_list.append(event)

    def on_created(self, event):
        self.create_queue.put(event)

    def on_modified(self, event):
        self.modify_queue.put(event)

    def on_moved(self, event):
        self.move_queue.put(event)

    def on_deleted(self, event):
        self.delete_queue.put(event)

def generate_name():
    return token_urlsafe(16)

@pytest.fixture()
def context():
    with TemporaryDirectory() as workdir:
        subdir = join(workdir, generate_name())
        subfile = join(subdir, generate_name()+'.txt')
        mkdir(subdir)
        Path(subfile).touch()
        monitor = InotifyTestObserver(workdir, recursive=True)
        monitor.start()
        yield InotifyContext(workdir, subdir, subfile, monitor)

def test_create(context):
    newfile = context.create_random_file(context.workdir, '.txt')
    context.check_event(InotifyFileCreatedEvent, newfile)
    newdir = context.create_random_folder(context.workdir)
    context.check_event(InotifyDirCreatedEvent, newdir)
    assert context.check_any()

def test_modify(context):
    with open(context.subfile, 'wb', buffering=0) as ofile:
        ofile.write(b'text')
    context.check_event(InotifyFileModifiedEvent, context.subfile)
    while True:
        try:
            context.observer.modify_queue.get(timeout=0.1)
            context.missing_events += 1
        except Empty:
            break
    rename(context.subfile, context.subfile+'_new')
    context.check_event(InotifyDirModifiedEvent, context.subdir)
    assert context.check_any()

def test_move(context):
    rename(context.subdir, context.subdir+'_new')
    context.check_event(InotifyDirMovedEvent, context.subdir)
    context.check_event(InotifyFileMovedEvent, context.subfile)
    assert context.check_any()

def test_delete(context):
    rmtree(context.subdir)
    context.check_event(InotifyFileDeletedEvent, context.subfile)
    context.check_event(InotifyDirDeletedEvent, context.subdir)
    assert context.check_any()

def test_paths(context):
    context.observer.paths = context.subdir
    newdir = context.create_random_folder(context.workdir)
    newfile = context.create_random_file(context.subdir, '.txt')
    context.check_event(InotifyDirModifiedEvent, context.subdir)
    context.check_event(InotifyFileCreatedEvent, newfile)
    context.observer.paths = [newdir, newfile]
    remove(newfile)
    context.check_event(InotifyFileDeletedEvent, newfile)
    assert context.check_any()
    context.observer.paths = context.workdir

def test_patterns(context):
    context.observer.patterns = ['*.txt']
    context.create_random_file(context.subdir, '.bin')
    newfile = context.create_random_file(context.subdir, '.txt')
    context.check_event(InotifyFileCreatedEvent, newfile)
    context.check_empty(InotifyFileCreatedEvent)
    assert context.check_any()
    context.observer.patterns = None

def test_exclude(context):
    context.observer.exclude = ['*.txt']
    context.create_random_file(context.subdir, '.txt')
    newfile = context.create_random_file(context.subdir, '.bin')
    context.check_event(InotifyFileCreatedEvent, newfile)
    context.check_empty(InotifyFileCreatedEvent)
    assert context.check_any()
    context.observer.exclude = None

def test_stress(context):
    newfile = []
    for i in range(1024):
        newfile.append(context.create_random_file(context.subdir, '.txt'))
    for i in range(1024):
        context.check_event(InotifyFileCreatedEvent, newfile[i])
    assert context.check_any()

tfw/internals/lazy.py (27 lines, Normal file)
@@ -0,0 +1,27 @@
from functools import update_wrapper, wraps


class lazy_property:
    """
    Decorator that replaces a function with the value
    it calculates on the first call.
    """
    def __init__(self, func):
        self.func = func
        update_wrapper(self, func)

    def __get__(self, instance, owner):
        if instance is None:
            return self  # avoids potential __new__ TypeError
        value = self.func(instance)
        setattr(instance, self.func.__name__, value)
        return value


def lazy_factory(fun):
    class wrapper:
        @wraps(fun)
        @lazy_property
        def instance(self):  # pylint: disable=no-self-use
            return fun()
    return wrapper()
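
# Illustrative usage sketch (not part of the diff above): the first access runs
# the function and shadows it with the result on the instance, so later
# accesses are plain attribute lookups. The Expensive class is invented.
from tfw.internals.lazy import lazy_property


class Expensive:
    @lazy_property
    def data(self):
        print('computing...')
        return [1, 2, 3]


obj = Expensive()
obj.data    # prints 'computing...' once and caches [1, 2, 3]
obj.data    # served from the instance attribute, no recomputation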

tfw/internals/networking/__init__.py (4 lines, Normal file)
@@ -0,0 +1,4 @@
from .serialization import serialize_tfw_msg, deserialize_tfw_msg, with_deserialize_tfw_msg, message_bytes
from .server_connector import ServerUplinkConnector, ServerDownlinkConnector, ServerConnector
from .event_handler_connector import EventHandlerConnector
from .scope import Scope

48  tfw/internals/networking/event_handler_connector.py  Normal file
@@ -0,0 +1,48 @@
import logging

import zmq
from zmq.eventloop.zmqstream import ZMQStream

from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg

LOG = logging.getLogger(__name__)


class EventHandlerDownlinkConnector:
    def __init__(self, bind_addr):
        self._zmq_pull_socket = zmq.Context.instance().socket(zmq.PULL)
        self._zmq_pull_socket.setsockopt(zmq.RCVHWM, 0)
        self._zmq_pull_stream = ZMQStream(self._zmq_pull_socket)
        self._zmq_pull_socket.bind(bind_addr)
        LOG.debug('Pull socket bound to %s', bind_addr)

    def register_callback(self, callback):
        callback = with_deserialize_tfw_msg(callback)
        self._zmq_pull_stream.on_recv(callback)

    def close(self):
        self._zmq_pull_stream.close()


class EventHandlerUplinkConnector:
    def __init__(self, bind_addr):
        self._zmq_pub_socket = zmq.Context.instance().socket(zmq.PUB)
        self._zmq_pub_socket.setsockopt(zmq.SNDHWM, 0)
        self._zmq_pub_socket.bind(bind_addr)
        LOG.debug('Pub socket bound to %s', bind_addr)

    def send_message(self, message: dict):
        self._zmq_pub_socket.send_multipart(serialize_tfw_msg(message))

    def close(self):
        self._zmq_pub_socket.close()


class EventHandlerConnector(EventHandlerDownlinkConnector, EventHandlerUplinkConnector):
    def __init__(self, downlink_bind_addr, uplink_bind_addr):
        EventHandlerDownlinkConnector.__init__(self, downlink_bind_addr)
        EventHandlerUplinkConnector.__init__(self, uplink_bind_addr)

    def close(self):
        EventHandlerDownlinkConnector.close(self)
        EventHandlerUplinkConnector.close(self)
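A rough sketch of how a server process might wire this up (the port numbers are illustrative and the callback is a stand-in): the PULL stream only delivers on_recv callbacks once a Tornado IOLoop is running.

from tornado.ioloop import IOLoop
from tfw.internals.networking import EventHandlerConnector

connector = EventHandlerConnector(
    downlink_bind_addr='tcp://*:7654',  # PULL socket, receives from event handlers
    uplink_bind_addr='tcp://*:7655'     # PUB socket, broadcasts to event handlers
)
connector.register_callback(lambda msg: print('from event handler:', msg))
connector.send_message({'key': 'ping', 'data': {}})  # published to all subscribers
IOLoop.current().start()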

7  tfw/internals/networking/scope.py  Normal file
@@ -0,0 +1,7 @@
from enum import Enum


class Scope(Enum):
    ZMQ = 'zmq'
    WEBSOCKET = 'websocket'
    BROADCAST = 'broadcast'

104  tfw/internals/networking/serialization.py  Normal file
@@ -0,0 +1,104 @@
"""
TFW JSON message format

message:
{
    "key":     string,    # addressing
    "data":    {...},     # payload
    "trigger": string     # FSM trigger
}

ZeroMQ's pub-sub sockets use enveloped messages
(http://zguide.zeromq.org/page:all#Pub-Sub-Message-Envelopes)
and TFW also uses them internally. This means that on ZMQ sockets
we always send the message's key separately and then the actual
message (which contains the key as well) like so:

socket.send_multipart([message['key'], message])

The purpose of this module is abstracting away this low-level behaviour.
"""

import json
from functools import wraps


def serialize_tfw_msg(message):
    """
    Create TFW multipart data from message dict
    """
    return _serialize_all(message['key'], message)


def with_deserialize_tfw_msg(fun):
    @wraps(fun)
    def wrapper(message_parts):
        message = deserialize_tfw_msg(*message_parts)
        return fun(message)
    return wrapper


def deserialize_tfw_msg(*args):
    """
    Return message from TFW multipart data
    """
    return _deserialize_all(*args)[1]


def _serialize_all(*args):
    return tuple(
        _serialize_single(arg)
        for arg in args
    )


def _deserialize_all(*args):
    return tuple(
        _deserialize_single(arg)
        for arg in args
    )


def _serialize_single(data):
    """
    Return input as bytes
    (serialize input if it is JSON)
    """
    if not isinstance(data, str):
        data = message_bytes(data)
    return _encode_if_needed(data)


def message_bytes(message):
    return json.dumps(message, sort_keys=True).encode()


def _deserialize_single(data):
    """
    Try parsing input as JSON, return it as
    string if parsing fails.
    """
    try:
        return json.loads(data)
    except ValueError:
        return _decode_if_needed(data)


def _encode_if_needed(value):
    """
    Return input as bytes
    (encode if input is string)
    """
    if isinstance(value, str):
        value = value.encode('utf-8')
    return value


def _decode_if_needed(value):
    """
    Return input as string
    (decode if input is bytes)
    """
    if isinstance(value, (bytes, bytearray)):
        value = value.decode('utf-8')
    return value
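A quick round trip, assuming the package is importable, makes the envelope visible: the key travels as the first frame and again inside the JSON payload, and deserialization recovers the original dict.

from tfw.internals.networking import serialize_tfw_msg, deserialize_tfw_msg

message = {'key': 'console', 'data': {'content': 'hello'}}
frames = serialize_tfw_msg(message)
# frames == (b'console', b'{"data": {"content": "hello"}, "key": "console"}')
assert deserialize_tfw_msg(*frames) == message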

65  tfw/internals/networking/server_connector.py  Normal file
@@ -0,0 +1,65 @@
import logging

import zmq
from zmq.eventloop.zmqstream import ZMQStream

from .scope import Scope
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg

LOG = logging.getLogger(__name__)


class ServerDownlinkConnector:
    def __init__(self, connect_addr):
        self.keys = []
        self._on_recv_callback = None
        self._zmq_sub_socket = zmq.Context.instance().socket(zmq.SUB)
        self._zmq_sub_socket.setsockopt(zmq.RCVHWM, 0)
        self._zmq_sub_socket.connect(connect_addr)
        self._zmq_sub_stream = ZMQStream(self._zmq_sub_socket)

    def subscribe(self, *keys):
        for key in keys:
            self._zmq_sub_socket.setsockopt_string(zmq.SUBSCRIBE, key)
            self.keys.append(key)

    def unsubscribe(self, *keys):
        for key in keys:
            self._zmq_sub_socket.setsockopt_string(zmq.UNSUBSCRIBE, key)
            self.keys.remove(key)

    def register_callback(self, callback):
        self._on_recv_callback = callback
        self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))

    def _on_recv(self, message):
        key = message['key']
        if key in self.keys or '' in self.keys:
            self._on_recv_callback(message)

    def close(self):
        self._zmq_sub_stream.close()


class ServerUplinkConnector:
    def __init__(self, connect_addr):
        self._zmq_push_socket = zmq.Context.instance().socket(zmq.PUSH)
        self._zmq_push_socket.setsockopt(zmq.SNDHWM, 0)
        self._zmq_push_socket.connect(connect_addr)

    def send_message(self, message, scope=Scope.ZMQ):
        message['scope'] = scope.value
        self._zmq_push_socket.send_multipart(serialize_tfw_msg(message))

    def close(self):
        self._zmq_push_socket.close()


class ServerConnector(ServerDownlinkConnector, ServerUplinkConnector):
    def __init__(self, downlink_connect_addr, uplink_connect_addr):
        ServerDownlinkConnector.__init__(self, downlink_connect_addr)
        ServerUplinkConnector.__init__(self, uplink_connect_addr)

    def close(self):
        ServerDownlinkConnector.close(self)
        ServerUplinkConnector.close(self)
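A sketch of event-handler-side usage, with illustrative connect addresses (in practice TFWConnector in tfw.main fills these in from TFWENV): subscribe to the keys of interest, register a callback, and send messages tagged with a Scope.

from tornado.ioloop import IOLoop
from tfw.internals.networking import ServerConnector, Scope

connector = ServerConnector(
    downlink_connect_addr='tcp://localhost:7655',  # SUB side
    uplink_connect_addr='tcp://localhost:7654'     # PUSH side
)
connector.subscribe('console', 'message')
connector.register_callback(lambda msg: print('received:', msg))
connector.send_message({'key': 'console', 'data': {'content': 'hi'}}, scope=Scope.WEBSOCKET)
IOLoop.current().start()  # the SUB stream only delivers while the IOLoop runs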

1  tfw/internals/server/__init__.py  Normal file
@@ -0,0 +1 @@
from .zmq_websocket_router import ZMQWebSocketRouter

69  tfw/internals/server/zmq_websocket_router.py  Normal file
@@ -0,0 +1,69 @@
import json
import logging

from tornado.websocket import WebSocketHandler

from tfw.internals.networking import Scope

LOG = logging.getLogger(__name__)


class ZMQWebSocketRouter(WebSocketHandler):
    # pylint: disable=abstract-method,attribute-defined-outside-init
    instances = set()

    def initialize(self, **kwargs):
        self.event_handler_connector = kwargs['event_handler_connector']
        self.tfw_router = TFWRouter(self.send_to_zmq, self.send_to_websockets)

    def send_to_zmq(self, message):
        self.event_handler_connector.send_message(message)

    @classmethod
    def send_to_websockets(cls, message):
        for instance in cls.instances:
            instance.write_message(message)

    def prepare(self):
        type(self).instances.add(self)

    def on_close(self):
        type(self).instances.remove(self)

    def open(self, *args, **kwargs):
        LOG.debug('WebSocket connection initiated!')
        self.event_handler_connector.register_callback(self.zmq_callback)

    def zmq_callback(self, message):
        LOG.debug('Received on ZMQ pull socket: %s', message)
        self.tfw_router.route(message)

    def on_message(self, message):
        message = json.loads(message)
        LOG.debug('Received on WebSocket: %s', message)
        self.tfw_router.route(message)

    # much secure, very cors, wow
    def check_origin(self, origin):
        return True


class TFWRouter:
    def __init__(self, send_to_zmq, send_to_websockets):
        self.send_to_zmq = send_to_zmq
        self.send_to_websockets = send_to_websockets

    def route(self, message):
        scope = Scope(message.pop('scope', 'zmq'))

        routing_table = {
            Scope.ZMQ: self.send_to_zmq,
            Scope.WEBSOCKET: self.send_to_websockets,
            Scope.BROADCAST: self.broadcast
        }
        action = routing_table[scope]
        action(message)

    def broadcast(self, message):
        self.send_to_zmq(message)
        self.send_to_websockets(message)
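The scope dispatch is easiest to see in isolation; in this toy example (print functions stand in for the real ZMQ and WebSocket senders) the 'scope' field set by ServerUplinkConnector selects the destination and defaults to ZMQ when absent.

from tfw.internals.server.zmq_websocket_router import TFWRouter

router = TFWRouter(
    send_to_zmq=lambda msg: print('-> zmq', msg),
    send_to_websockets=lambda msg: print('-> websocket', msg)
)
router.route({'key': 'console', 'scope': 'websocket'})  # websocket only
router.route({'key': 'deploy'})                         # no scope: defaults to zmq
router.route({'key': 'message', 'scope': 'broadcast'})  # both directions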

132  tfw/logging.py  Normal file
@@ -0,0 +1,132 @@
# pylint: disable=bad-whitespace
from datetime import datetime
from typing import TextIO, Union
from dataclasses import dataclass
from traceback import format_exception
from logging import DEBUG, getLogger, Handler, Formatter, Filter


class Color:
    RED     = '\033[31m'
    GREEN   = '\033[32m'
    YELLOW  = '\033[33m'
    BLUE    = '\033[34m'
    CYAN    = '\033[36m'
    WHITE   = '\033[37m'
    RESET   = '\033[0m'


@dataclass
class Log:
    stream: Union[str, TextIO]
    formatter: Formatter


class Logger:
    def __init__(self, logs, level=DEBUG):
        self.root_logger = getLogger()
        self.old_level = self.root_logger.level
        self.new_level = level
        self.handlers = []
        for log in logs:
            handler = LogHandler(log.stream)
            handler.setFormatter(log.formatter)
            self.handlers.append(handler)

    def start(self):
        self.root_logger.setLevel(self.new_level)
        for handler in self.handlers:
            self.root_logger.addHandler(handler)

    def stop(self):
        self.root_logger.setLevel(self.old_level)
        for handler in self.handlers:
            handler.close()
            self.root_logger.removeHandler(handler)


class LogHandler(Handler):
    def __init__(self, stream):
        if isinstance(stream, str):
            self.stream = open(stream, 'a+')
            self.close_stream = True
        else:
            self.stream = stream
            self.close_stream = False
        super().__init__()

    def emit(self, record):
        entry = self.format(record)
        self.stream.write(entry+'\n')
        self.stream.flush()

    def close(self):
        if self.close_stream:
            self.stream.close()

class LogFormatter(Formatter):
    severity_to_color = {
        'CRITICAL' : Color.RED,
        'ERROR'    : Color.RED,
        'WARNING'  : Color.YELLOW,
        'INFO'     : Color.GREEN,
        'DEBUG'    : Color.BLUE,
        'NOTSET'   : Color.CYAN
    }

    def __init__(self, limit):
        self.limit = limit
        super().__init__()

    def format(self, record):
        time = datetime.utcfromtimestamp(record.created).strftime('%H:%M:%S')
        if record.args:
            tuple_args = (record.args,) if isinstance(record.args, dict) else record.args
            clean_args = tuple((self.trim(arg) for arg in tuple_args))
            message = record.msg % clean_args
        else:
            message = record.msg
        trace = '\n'+''.join(format_exception(*record.exc_info)) if record.exc_info else ''

        return (f'[{Color.WHITE}{time}{Color.RESET}|>'
                f'{self.severity_to_color[record.levelname]}{record.module}:'
                f'{record.levelname.lower()}{Color.RESET}] {message}{trace}')

    def trim(self, value):
        if isinstance(value, dict):
            return {k: self.trim(v) for k, v in value.items()}
        if isinstance(value, str):
            value_str = str(value)
            return value_str if len(value_str) <= self.limit else f'{value_str[:self.limit]}...'
        return value


class VerboseLogFormatter(Formatter):
    def format(self, record):  # pylint: disable=no-self-use
        date = datetime.utcfromtimestamp(record.created).strftime('%H:%M:%S')
        if record.args:
            message = record.msg % record.args
        else:
            message = record.msg
        trace = '\n'+''.join(format_exception(*record.exc_info)) if record.exc_info else ''

        return (f'[{date}|>{record.module}:{record.levelname.lower()}] '
                f'{message}{trace}')


class WhitelistFilter(Filter):
    def __init__(self, names):
        self.names = names
        super().__init__()

    def filter(self, record):
        return record.module in self.names


class BlacklistFilter(Filter):
    def __init__(self, names):
        self.names = names
        super().__init__()

    def filter(self, record):
        return record.module not in self.names
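A possible setup, assuming the module is importable as tfw.logging (the log file path is just an example): colored output goes to stderr via LogFormatter, while a verbose plain copy is appended to a file.

import sys
import logging
from tfw.logging import Log, Logger, LogFormatter, VerboseLogFormatter

logger = Logger([
    Log(sys.stderr, LogFormatter(limit=256)),
    Log('/tmp/tfw.log', VerboseLogFormatter())
])
logger.start()
logging.getLogger(__name__).info('Hello from %s', 'TFW')
logger.stop()  # restores the previous root level and closes the file handler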

4  tfw/main/__init__.py  Normal file
@@ -0,0 +1,4 @@
from .tfw_connector import TFWUplinkConnector, TFWConnector
from .event_handler_factory import EventHandlerFactory
from .signal_handling import setup_signal_handlers
from .tfw_server import TFWServer

8  tfw/main/event_handler_factory.py  Normal file
@@ -0,0 +1,8 @@
from tfw.internals.event_handling import EventHandlerFactoryBase

from .tfw_connector import TFWConnector


class EventHandlerFactory(EventHandlerFactoryBase):
    def _build_server_connector(self):
        return TFWConnector()

11  tfw/main/signal_handling.py  Normal file
@@ -0,0 +1,11 @@
from signal import signal, SIGTERM, SIGINT

from tfw.internals.event_handling import EventHandler


def setup_signal_handlers():
    def stop(*_):
        EventHandler.stop_all_instances()
        exit(0)
    signal(SIGTERM, stop)
    signal(SIGINT, stop)

25  tfw/main/tfw_connector.py  Normal file
@@ -0,0 +1,25 @@
from tfw.internals.networking import ServerConnector, ServerUplinkConnector
from tfw.config import TFWENV


class ConnAddrMixin:
    @property
    def uplink_conn_addr(self):
        return f'tcp://localhost:{TFWENV.PULL_PORT}'

    @property
    def downlink_conn_addr(self):
        return f'tcp://localhost:{TFWENV.PUB_PORT}'


class TFWUplinkConnector(ServerUplinkConnector, ConnAddrMixin):
    def __init__(self):
        super().__init__(self.uplink_conn_addr)


class TFWConnector(ServerConnector, ConnAddrMixin):
    def __init__(self):
        super().__init__(
            self.downlink_conn_addr,
            self.uplink_conn_addr
        )

31  tfw/main/tfw_server.py  Normal file
@@ -0,0 +1,31 @@
import logging

from tornado.web import Application

from tfw.internals.networking import EventHandlerConnector
from tfw.internals.server import ZMQWebSocketRouter
from tfw.config import TFWENV


LOG = logging.getLogger(__name__)


class TFWServer:
    """
    This class handles the proxying of messages between the frontend and event handlers.
    It proxies messages from the "/ws" route to all event handlers subscribed to a ZMQ
    SUB socket.
    """
    def __init__(self):
        self._event_handler_connector = EventHandlerConnector(
            downlink_bind_addr=f'tcp://*:{TFWENV.PULL_PORT}',
            uplink_bind_addr=f'tcp://*:{TFWENV.PUB_PORT}'
        )
        self.application = Application([(
            r'/ws', ZMQWebSocketRouter, {
                'event_handler_connector': self._event_handler_connector,
            }
        )])

    def listen(self):
        self.application.listen(TFWENV.WEB_PORT)
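Putting the pieces together, a minimal entry point might look like this sketch (ports come from TFWENV, and Tornado's IOLoop must be started for both the WebSocket route and the ZMQ streams to work):

from tornado.ioloop import IOLoop
from tfw.main import TFWServer, setup_signal_handlers

setup_signal_handlers()  # stops EventHandler instances on SIGTERM/SIGINT
server = TFWServer()
server.listen()
IOLoop.current().start()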