Simplify package structure

This commit is contained in:
Kristóf Tóth
2019-07-24 15:50:41 +02:00
parent a23224aced
commit 52399f413c
79 changed files with 22 additions and 24 deletions

0
tfw/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,2 @@
from .frontend_handler import FrontendHandler
from .message_sender import MessageSender

View File

@ -0,0 +1,25 @@
from tfw.internals.networking import Scope
from .message_storage import FrontendMessageStorage
class FrontendHandler:
keys = ['message', 'queueMessages', 'dashboard', 'console']
def __init__(self):
self.server_connector = None
self.keys = [*type(self).keys, 'recover']
self._frontend_message_storage = FrontendMessageStorage(type(self).keys)
def send_message(self, message):
self.server_connector.send_message(message, scope=Scope.WEBSOCKET)
def handle_event(self, message, _):
self._frontend_message_storage.save_message(message)
if message['key'] == 'recover':
self.recover_frontend()
self.send_message(message)
def recover_frontend(self):
for message in self._frontend_message_storage.messages:
self.send_message(message)

View File

@ -0,0 +1,48 @@
class MessageSender:
"""
Provides mechanisms to send messages to our frontend messaging component.
"""
def __init__(self, uplink):
self.uplink = uplink
self.key = 'message'
self.queue_key = 'queueMessages'
def send(self, originator, message):
"""
Sends a message.
:param originator: name of sender to be displayed on the frontend
:param message: message to send
"""
message = {
'key': self.key,
'data': {
'originator': originator,
'message': message
}
}
self.uplink.send_message(message)
def queue_messages(self, originator, messages):
"""
Queues a list of messages to be displayed in a chatbot-like manner.
:param originator: name of sender to be displayed on the frontend
:param messages: list of messages to queue
"""
message = {
'key': self.queue_key,
'data': {
'messages': [
{'message': message, 'originator': originator}
for message in messages
]
}
}
self.uplink.send_message(message)
@staticmethod
def generate_messages_from_queue(queue_message):
for message in queue_message['data']['messages']:
yield {
'key': 'message',
'data': message
}

View File

@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from contextlib import suppress
from .message_sender import MessageSender
class MessageStorage(ABC):
def __init__(self):
self._messages = []
def save_message(self, message):
with suppress(KeyError, AttributeError):
if self._filter_message(message):
self._messages.extend(self._transform_message(message))
@abstractmethod
def _filter_message(self, message):
raise NotImplementedError
def _transform_message(self, message): # pylint: disable=no-self-use
yield message
def clear(self):
self._messages.clear()
@property
def messages(self):
yield from self._messages
class FrontendMessageStorage(MessageStorage):
def __init__(self, keys):
self._keys = keys
super().__init__()
def _filter_message(self, message):
key = message['key']
return key in self._keys
def _transform_message(self, message):
if message['key'] == 'queueMessages':
yield from MessageSender.generate_messages_from_queue(message)
else:
yield message

View File

@ -0,0 +1 @@
from .fsm_handler import FSMHandler

View File

@ -0,0 +1,71 @@
import logging
from tfw.internals.crypto import KeyManager, sign_message, verify_message
from tfw.internals.networking import Scope
from .fsm_updater import FSMUpdater
LOG = logging.getLogger(__name__)
class FSMHandler:
keys = ['fsm']
"""
EventHandler responsible for managing the state machine of
the framework (TFW FSM).
tfw.networking.TFWServer instances automatically send 'trigger'
commands to the event handler listening on the 'fsm' key,
which should be an instance of this event handler.
This event handler accepts messages that have a
data['command'] key specifying a command to be executed.
An 'fsm_update' message is broadcasted after every successful
command.
"""
def __init__(self, *, fsm_type, require_signature=False):
self.fsm = fsm_type()
self._fsm_updater = FSMUpdater(self.fsm)
self.auth_key = KeyManager().auth_key
self._require_signature = require_signature
self.command_handlers = {
'trigger': self.handle_trigger,
'update': self.handle_update
}
def handle_event(self, message, server_connector):
try:
message = self.command_handlers[message['data']['command']](message)
if message:
fsm_update_message = self._fsm_updater.fsm_update
sign_message(self.auth_key, message)
sign_message(self.auth_key, fsm_update_message)
server_connector.send_message(fsm_update_message, Scope.BROADCAST)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def handle_trigger(self, message):
"""
Attempts to step the FSM with the supplied trigger.
:param message: TFW message with a data field containing
the action to try triggering in data['value']
"""
trigger = message['data']['value']
if self._require_signature:
if not verify_message(self.auth_key, message):
LOG.error('Ignoring unsigned trigger command: %s', message)
return None
if self.fsm.step(trigger):
return message
return None
def handle_update(self, message):
"""
Does nothing, but triggers an 'fsm_update' message.
"""
# pylint: disable=no-self-use
return message

View File

@ -0,0 +1,25 @@
class FSMUpdater:
def __init__(self, fsm):
self.fsm = fsm
@property
def fsm_update(self):
return {
'key': 'fsm_update',
**self.fsm_update_data
}
@property
def fsm_update_data(self):
valid_transitions = [
{'trigger': trigger}
for trigger in self.fsm.get_triggers(self.fsm.state)
]
last_fsm_event = self.fsm.event_log[-1]
last_fsm_event['timestamp'] = last_fsm_event['timestamp'].isoformat()
return {
'current_state': self.fsm.state,
'valid_transitions': valid_transitions,
'in_accepted_state': self.fsm.in_accepted_state,
'last_event': last_fsm_event
}

View File

@ -0,0 +1 @@
from .ide_handler import IdeHandler

View File

@ -0,0 +1 @@
from .file_manager import FileManager

View File

@ -0,0 +1,93 @@
from typing import Iterable
from glob import glob
from fnmatch import fnmatchcase
from os.path import basename, isfile, join, relpath, exists, isdir, realpath
class FileManager: # pylint: disable=too-many-instance-attributes
def __init__(self, working_directory, allowed_directories, selected_file=None, exclude=None):
self._exclude, self.exclude = [], exclude
self._allowed_directories, self.allowed_directories = None, allowed_directories
self._workdir, self.workdir = None, working_directory
self._filename, self.filename = None, selected_file or self.files[0]
@property
def exclude(self):
return self._exclude
@exclude.setter
def exclude(self, exclude):
if exclude is None:
return
if not isinstance(exclude, Iterable):
raise TypeError('Exclude must be Iterable!')
self._exclude = exclude
@property
def workdir(self):
return self._workdir
@workdir.setter
def workdir(self, directory):
if not exists(directory) or not isdir(directory):
raise EnvironmentError(f'"{directory}" is not a directory!')
if not self._is_in_allowed_dir(directory):
raise EnvironmentError(f'Directory "{directory}" is not allowed!')
self._workdir = directory
@property
def allowed_directories(self):
return self._allowed_directories
@allowed_directories.setter
def allowed_directories(self, directories):
self._allowed_directories = [realpath(directory) for directory in directories]
@property
def filename(self):
return self._filename
@filename.setter
def filename(self, filename):
if filename not in self.files:
raise EnvironmentError('No such file in workdir!')
self._filename = filename
@property
def files(self):
return [
self._relpath(file)
for file in glob(join(self._workdir, '**/*'), recursive=True)
if isfile(file)
and self._is_in_allowed_dir(file)
and not self._is_blacklisted(file)
]
@property
def file_contents(self):
with open(self._filepath(self.filename), 'rb', buffering=0) as ifile:
return ifile.read().decode(errors='surrogateescape')
@file_contents.setter
def file_contents(self, value):
with open(self._filepath(self.filename), 'wb', buffering=0) as ofile:
ofile.write(value.encode())
def _is_in_allowed_dir(self, path):
return any(
realpath(path).startswith(allowed_dir)
for allowed_dir in self.allowed_directories
)
def _is_blacklisted(self, file):
return any(
fnmatchcase(file, blacklisted) or
fnmatchcase(basename(file), blacklisted)
for blacklisted in self.exclude
)
def _filepath(self, filename):
return join(self._workdir, filename)
def _relpath(self, filename):
return relpath(self._filepath(filename), start=self._workdir)

View File

@ -0,0 +1,124 @@
# pylint: disable=redefined-outer-name
from dataclasses import dataclass
from secrets import token_urlsafe
from os.path import join
from os import chdir, mkdir, symlink
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from .file_manager import FileManager
@dataclass
class ManagerContext:
folder: str
manager: FileManager
def join(self, path):
return join(self.folder, path)
@pytest.fixture()
def context():
dirs = {}
with TemporaryDirectory() as workdir:
chdir(workdir)
for name in ['allowed', 'excluded', 'invis']:
node = join(workdir, name)
mkdir(node)
Path(join(node, 'empty.txt')).touch()
Path(join(node, 'empty.bin')).touch()
dirs[name] = node
yield ManagerContext(
workdir,
FileManager(
dirs['allowed'],
[dirs['allowed'], dirs['excluded']],
exclude=['*/excluded/*']
)
)
@pytest.mark.parametrize('subdir', ['allowed/', 'excluded/'])
def test_select_allowed_dirs(context, subdir):
context.manager.workdir = context.join(subdir)
assert context.manager.workdir == context.join(subdir)
newdir = context.join(subdir+'deep')
mkdir(newdir)
context.manager.workdir = newdir
assert context.manager.workdir == newdir
@pytest.mark.parametrize('invdir', ['', 'invis'])
def test_select_forbidden_dirs(context, invdir):
fullpath = context.join(invdir)
with pytest.raises(OSError):
context.manager.workdir = fullpath
assert context.manager.workdir != fullpath
context.manager.allowed_directories += [fullpath]
context.manager.workdir = fullpath
assert context.manager.workdir == fullpath
@pytest.mark.parametrize('filename', ['another.txt', '*.txt'])
def test_select_allowed_files(context, filename):
Path(context.join('allowed/'+filename)).touch()
assert filename in context.manager.files
context.manager.filename = filename
assert context.manager.filename == filename
@pytest.mark.parametrize('path', [
{'dir': 'allowed/', 'file': 'illegal.bin'},
{'dir': 'excluded/', 'file': 'legal.txt'},
{'dir': 'allowed/', 'file': token_urlsafe(16)+'.bin'},
{'dir': 'excluded/', 'file': token_urlsafe(16)+'.txt'},
{'dir': 'allowed/', 'file': token_urlsafe(32)+'.bin'},
{'dir': 'excluded/', 'file': token_urlsafe(32)+'.txt'}
])
def test_select_excluded_files(context, path):
context.manager.workdir = context.join(path['dir'])
context.manager.exclude = ['*/excluded/*', '*.bin']
Path(context.join(path['dir']+path['file'])).touch()
assert path['file'] not in context.manager.files
with pytest.raises(OSError):
context.manager.filename = path['file']
@pytest.mark.parametrize('path', [
{'src': 'excluded/empty.txt', 'dst': 'allowed/link.txt'},
{'src': 'invis/empty.txt', 'dst': 'allowed/link.txt'},
{'src': 'excluded/empty.txt', 'dst': 'allowed/'+token_urlsafe(16)+'.txt'},
{'src': 'invis/empty.txt', 'dst': 'allowed/'+token_urlsafe(16)+'.txt'},
{'src': 'excluded/empty.txt', 'dst': 'allowed/'+token_urlsafe(32)+'.txt'},
{'src': 'invis/empty.txt', 'dst': 'allowed/'+token_urlsafe(32)+'.txt'}
])
def test_select_excluded_symlinks(context, path):
symlink(context.join(path['src']), context.join(path['dst']))
assert path['dst'] not in context.manager.files
def test_read_write_file(context):
for _ in range(128):
context.manager.filename = 'empty.txt'
content = token_urlsafe(32)
context.manager.file_contents = content
assert context.manager.file_contents == content
with open(context.join('allowed/empty.txt'), 'r') as ifile:
assert ifile.read() == content
def test_regular_ide_actions(context):
context.manager.workdir = context.join('allowed')
newfile1, newfile2 = token_urlsafe(16), token_urlsafe(16)
Path(context.join(f'allowed/{newfile1}')).touch()
Path(context.join(f'allowed/{newfile2}')).touch()
for _ in range(8):
context.manager.filename = newfile1
content1 = token_urlsafe(32)
context.manager.file_contents = content1
context.manager.filename = newfile2
content2 = token_urlsafe(32)
context.manager.file_contents = content2
context.manager.filename = newfile1
assert context.manager.file_contents == content1
context.manager.filename = newfile2
assert context.manager.file_contents == content2

View File

@ -0,0 +1,196 @@
import logging
from tfw.internals.networking import Scope
from tfw.internals.inotify import InotifyObserver
from .file_manager import FileManager
LOG = logging.getLogger(__name__)
BUILD_ARTIFACTS = (
"*.a",
"*.class",
"*.dll",
"*.dylib",
"*.elf",
"*.exe",
"*.jar",
"*.ko",
"*.la",
"*.lib",
"*.lo",
"*.o",
"*.obj",
"*.out",
"*.py[cod]",
"*.so",
"*.so.*",
"*.tar.gz",
"*.zip",
"*__pycache__*"
)
class IdeHandler:
keys = ['ide']
# pylint: disable=too-many-arguments,anomalous-backslash-in-string
"""
Event handler implementing the backend of our browser based IDE.
By default all files in the directory specified in __init__ are displayed
on the fontend. Note that this is a stateful component.
When any file in the selected directory changes they are automatically refreshed
on the frontend (this is done by listening to inotify events).
This EventHandler accepts messages that have a data['command'] key specifying
a command to be executed.
The API of each command is documented in their respective handler.
"""
def __init__(self, *, directory, allowed_directories, selected_file=None, exclude=None):
"""
:param key: the key this instance should listen to
:param directory: working directory which the EventHandler should serve files from
:param allowed_directories: list of directories that can be switched to using selectdir
:param selected_file: file that is selected by default
:param exclude: list of filenames that should not appear between files (for .o, .pyc, etc.)
"""
self.server_connector = None
try:
self.filemanager = FileManager(
allowed_directories=allowed_directories,
working_directory=directory,
selected_file=selected_file,
exclude=exclude
)
except IndexError:
raise EnvironmentError(
f'No file(s) in IdeEventHandler working_directory "{directory}"!'
)
self.monitor = InotifyObserver(
self.filemanager.allowed_directories,
exclude=BUILD_ARTIFACTS
)
self.monitor.on_modified = self._reload_frontend
self.monitor.start()
self.commands = {
'read': self.read,
'write': self.write,
'select': self.select,
'selectdir': self.select_dir,
'exclude': self.exclude
}
def _reload_frontend(self, event): # pylint: disable=unused-argument
self.send_message({
'key': 'ide',
'data': {'command': 'reload'}
})
def send_message(self, message):
self.server_connector.send_message(message, scope=Scope.WEBSOCKET)
def read(self, data):
"""
Read the currently selected file.
:return dict: TFW message data containing key 'content'
(contents of the selected file)
"""
try:
data['content'] = self.filemanager.file_contents
except PermissionError:
data['content'] = 'You have no permission to open that file :('
except FileNotFoundError:
data['content'] = 'This file was removed :('
except Exception: # pylint: disable=broad-except
data['content'] = 'Failed to read file :('
return data
def write(self, data):
"""
Overwrites a file with the desired string.
:param data: TFW message data containing key 'content'
(new file content)
"""
try:
self.filemanager.file_contents = data['content']
except Exception: # pylint: disable=broad-except
LOG.exception('Error writing file!')
del data['content']
return data
def select(self, data):
"""
Selects a file from the current directory.
:param data: TFW message data containing 'filename'
(name of file to select relative to the current directory)
"""
try:
self.filemanager.filename = data['filename']
except EnvironmentError:
LOG.exception('Failed to select file "%s"', data['filename'])
return data
def select_dir(self, data):
"""
Select a new working directory to display files from.
:param data: TFW message data containing 'directory'
(absolute path of diretory to select.
must be a path whitelisted in
self.allowed_directories)
"""
try:
self.filemanager.workdir = data['directory']
try:
self.filemanager.filename = self.filemanager.files[0]
self.read(data)
except IndexError:
data['content'] = 'No files in this directory :('
except EnvironmentError as err:
LOG.error(
'Failed to select directory "%s". Reason: %s',
data['directory'], str(err)
)
return data
def exclude(self, data):
"""
Overwrite list of excluded files
:param data: TFW message data containing 'exclude'
(list of unix-style filename patterns to be excluded,
e.g.: ["\*.pyc", "\*.o")
"""
try:
self.filemanager.exclude = list(data['exclude'])
except TypeError:
LOG.error('Exclude must be Iterable!')
return data
def attach_fileinfo(self, data):
"""
Basic information included in every response to the frontend.
"""
data['filename'] = self.filemanager.filename
data['files'] = self.filemanager.files
data['directory'] = self.filemanager.workdir
def handle_event(self, message, _):
try:
data = message['data']
message['data'] = self.commands[data['command']](data)
self.attach_fileinfo(data)
self.send_message(message)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def cleanup(self):
self.monitor.stop()

View File

@ -0,0 +1 @@
from .pipe_io_handler import PipeIOHandler, PipeIOHandlerBase, TransformerPipeIOHandler, CommandHandler

View File

@ -0,0 +1,143 @@
import logging
from abc import abstractmethod
from json import loads, dumps
from subprocess import run, PIPE, Popen
from functools import partial
from os import getpgid, killpg
from os.path import join
from signal import SIGTERM
from secrets import token_urlsafe
from threading import Thread
from contextlib import suppress
from .pipe_io_server import PipeIOServer, terminate_process_on_failure
LOG = logging.getLogger(__name__)
DEFAULT_PERMISSIONS = 0o600
class PipeIOHandlerBase:
keys = ['']
def __init__(self, in_pipe_path, out_pipe_path, permissions=DEFAULT_PERMISSIONS):
self.server_connector = None
self.pipe_io = CallbackPipeIOServer(
in_pipe_path,
out_pipe_path,
self.handle_pipe_event,
permissions
)
self.pipe_io.start()
@abstractmethod
def handle_pipe_event(self, message_bytes):
raise NotImplementedError()
def cleanup(self):
self.pipe_io.stop()
class CallbackPipeIOServer(PipeIOServer):
def __init__(self, in_pipe_path, out_pipe_path, callback, permissions):
super().__init__(in_pipe_path, out_pipe_path, permissions)
self.callback = callback
def handle_message(self, message):
try:
self.callback(message)
except: # pylint: disable=bare-except
LOG.exception('Failed to handle message %s from pipe %s!', message, self.in_pipe)
class PipeIOHandler(PipeIOHandlerBase):
def handle_event(self, message, _):
json_bytes = dumps(message).encode()
self.pipe_io.send_message(json_bytes)
def handle_pipe_event(self, message_bytes):
json = loads(message_bytes)
self.server_connector.send_message(json)
class TransformerPipeIOHandler(PipeIOHandlerBase):
# pylint: disable=too-many-arguments
def __init__(
self, in_pipe_path, out_pipe_path,
transform_in_cmd, transform_out_cmd,
permissions=DEFAULT_PERMISSIONS
):
self._transform_in = partial(self._transform_message, transform_in_cmd)
self._transform_out = partial(self._transform_message, transform_out_cmd)
super().__init__(in_pipe_path, out_pipe_path, permissions)
@staticmethod
def _transform_message(transform_cmd, message):
proc = run(
transform_cmd,
input=message,
stdout=PIPE,
stderr=PIPE,
shell=True
)
if proc.returncode == 0:
return proc.stdout
raise ValueError(f'Transforming message {message} failed!')
def handle_event(self, message, _):
json_bytes = dumps(message).encode()
transformed_bytes = self._transform_out(json_bytes)
if transformed_bytes:
self.pipe_io.send_message(transformed_bytes)
def handle_pipe_event(self, message_bytes):
transformed_bytes = self._transform_in(message_bytes)
if transformed_bytes:
json_message = loads(transformed_bytes)
self.server_connector.send_message(json_message)
class CommandHandler(PipeIOHandler):
def __init__(self, command, permissions=DEFAULT_PERMISSIONS):
super().__init__(
self._generate_tempfilename(),
self._generate_tempfilename(),
permissions
)
self._proc_stdin = open(self.pipe_io.out_pipe, 'rb')
self._proc_stdout = open(self.pipe_io.in_pipe, 'wb')
self._proc = Popen(
command, shell=True, executable='/bin/bash',
stdin=self._proc_stdin, stdout=self._proc_stdout, stderr=PIPE,
start_new_session=True
)
self._monitor_proc_thread = self._start_monitor_proc()
def _generate_tempfilename(self):
# pylint: disable=no-self-use
random_filename = partial(token_urlsafe, 10)
return join('/tmp', f'{type(self).__name__}.{random_filename()}')
def _start_monitor_proc(self):
thread = Thread(target=self._monitor_proc, daemon=True)
thread.start()
return thread
@terminate_process_on_failure
def _monitor_proc(self):
return_code = self._proc.wait()
if return_code == -int(SIGTERM):
# supervisord asked the program to terminate, this is fine
return
if return_code != 0:
_, stderr = self._proc.communicate()
raise RuntimeError(f'Subprocess failed ({return_code})! Stderr:\n{stderr.decode()}')
def cleanup(self):
with suppress(ProcessLookupError):
process_group_id = getpgid(self._proc.pid)
killpg(process_group_id, SIGTERM)
self._proc_stdin.close()
self._proc_stdout.close()
super().cleanup()

View File

@ -0,0 +1,2 @@
from .pipe_io_server import PipeIOServer
from .terminate_process_on_failure import terminate_process_on_failure

View File

@ -0,0 +1,27 @@
from collections import deque
from threading import Lock, Condition
class Deque:
def __init__(self):
self._queue = deque()
self._mutex = Lock()
self._not_empty = Condition(self._mutex)
def pop(self):
with self._mutex:
while not self._queue:
self._not_empty.wait()
return self._queue.pop()
def push(self, item):
self._push(item, self._queue.appendleft)
def push_front(self, item):
self._push(item, self._queue.append)
def _push(self, item, put_method):
with self._mutex:
put_method(item)
self._not_empty.notify()

View File

@ -0,0 +1,16 @@
from os import mkfifo, remove, chmod
from os.path import exists
class Pipe:
def __init__(self, path):
self.path = path
def recreate(self, permissions):
self.remove()
mkfifo(self.path)
chmod(self.path, permissions) # use chmod to ignore umask
def remove(self):
if exists(self.path):
remove(self.path)

View File

@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from threading import Thread, Event
from typing import Callable
from .pipe_reader_thread import PipeReaderThread
from .pipe_writer_thread import PipeWriterThread
from .pipe import Pipe
from .terminate_process_on_failure import terminate_process_on_failure
class PipeIOServer(ABC, Thread):
def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600):
super().__init__(daemon=True)
self._in_pipe, self._out_pipe = in_pipe, out_pipe
self._create_pipes(permissions)
self._stop_event = Event()
self._reader_thread, self._writer_thread = self._create_io_threads()
self._io_threads = (self._reader_thread, self._writer_thread)
self._on_stop = lambda: None
def _create_pipes(self, permissions):
Pipe(self.in_pipe).recreate(permissions)
Pipe(self.out_pipe).recreate(permissions)
@property
def in_pipe(self):
return self._in_pipe
@property
def out_pipe(self):
return self._out_pipe
def _create_io_threads(self):
reader_thread = PipeReaderThread(self.in_pipe, self._stop_event, self.handle_message)
writer_thread = PipeWriterThread(self.out_pipe, self._stop_event)
return reader_thread, writer_thread
@abstractmethod
def handle_message(self, message):
raise NotImplementedError()
def send_message(self, message):
self._writer_thread.write(message)
@terminate_process_on_failure
def run(self):
for thread in self._io_threads:
thread.start()
self._stop_event.wait()
self._stop_threads()
def stop(self):
self._stop_event.set()
if self.is_alive():
self.join()
def _stop_threads(self):
for thread in self._io_threads:
if thread.is_alive():
thread.stop()
Pipe(self.in_pipe).remove()
Pipe(self.out_pipe).remove()
self._on_stop()
def _set_on_stop(self, value):
if not isinstance(value, Callable):
raise ValueError("Supplied object is not callable!")
self._on_stop = value
on_stop = property(fset=_set_on_stop)
def wait(self):
self._stop_event.wait()

View File

@ -0,0 +1,44 @@
from contextlib import suppress
from os import open as osopen
from os import write, close, O_WRONLY, O_NONBLOCK
from threading import Thread
from .terminate_process_on_failure import terminate_process_on_failure
class PipeReaderThread(Thread):
eof = b''
stop_sequence = b'stop_reading\n'
def __init__(self, pipe_path, stop_event, message_handler):
super().__init__(daemon=True)
self._message_handler = message_handler
self._pipe_path = pipe_path
self._stop_event = stop_event
@terminate_process_on_failure
def run(self):
with self._open() as pipe:
while True:
message = pipe.readline()
if message == self.stop_sequence:
self._stop_event.set()
break
if message == self.eof:
self._open().close()
continue
self._message_handler(message[:-1])
def _open(self):
return open(self._pipe_path, 'rb')
def stop(self):
while self.is_alive():
self._unblock()
self.join()
def _unblock(self):
with suppress(OSError):
fd = osopen(self._pipe_path, O_WRONLY | O_NONBLOCK)
write(fd, self.stop_sequence)
close(fd)

View File

@ -0,0 +1,50 @@
from contextlib import suppress
from os import O_NONBLOCK, O_RDONLY, close
from os import open as osopen
from threading import Thread
from .terminate_process_on_failure import terminate_process_on_failure
from .deque import Deque
class PipeWriterThread(Thread):
def __init__(self, pipe_path, stop_event):
super().__init__(daemon=True)
self._pipe_path = pipe_path
self._stop_event = stop_event
self._write_queue = Deque()
def write(self, message):
self._write_queue.push(message)
@terminate_process_on_failure
def run(self):
with self._open() as pipe:
while True:
message = self._write_queue.pop()
if message is None:
self._stop_event.set()
break
try:
pipe.write(message + b'\n')
pipe.flush()
except BrokenPipeError:
try: # pipe was reopened, close() flushed the message
pipe.close()
except BrokenPipeError: # close() discarded the message
self._write_queue.push_front(message)
pipe = self._open()
def _open(self):
return open(self._pipe_path, 'wb')
def stop(self):
while self.is_alive():
self._unblock()
self.join()
def _unblock(self):
with suppress(OSError):
fd = osopen(self._pipe_path, O_RDONLY | O_NONBLOCK)
self._write_queue.push_front(None)
close(fd)

View File

@ -0,0 +1,15 @@
from functools import wraps
from os import kill, getpid
from signal import SIGTERM
from traceback import print_exc
def terminate_process_on_failure(fun):
@wraps(fun)
def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except: # pylint: disable=bare-except
print_exc()
kill(getpid(), SIGTERM)
return wrapper

View File

@ -0,0 +1,2 @@
from .process_handler import ProcessHandler
from .process_log_handler import ProcessLogHandler

View File

@ -0,0 +1,45 @@
import logging
from tfw.internals.networking import Scope
from tfw.internals.inotify import InotifyObserver
from .supervisor import ProcessLogManager
class LogInotifyObserver(InotifyObserver, ProcessLogManager):
def __init__(self, server_connector, supervisor_uri, process_name, log_tail=0):
self._prevent_log_recursion()
self._server_connector = server_connector
self._process_name = process_name
self.log_tail = log_tail
self._procinfo = None
ProcessLogManager.__init__(self, supervisor_uri)
InotifyObserver.__init__(self, self._get_logfiles())
@staticmethod
def _prevent_log_recursion():
# This is done to prevent inotify event logs triggering themselves (infinite log recursion)
logging.getLogger('watchdog.observers.inotify_buffer').propagate = False
def _get_logfiles(self):
self._procinfo = self.supervisor.getProcessInfo(self._process_name)
return self._procinfo['stdout_logfile'], self._procinfo['stderr_logfile']
@property
def process_name(self):
return self._process_name
@process_name.setter
def process_name(self, process_name):
self._process_name = process_name
self.paths = self._get_logfiles()
def on_modified(self, event):
self._server_connector.send_message({
'key': 'processlog',
'data': {
'command': 'new_log',
'stdout': self.read_stdout(self.process_name, tail=self.log_tail),
'stderr': self.read_stderr(self.process_name, tail=self.log_tail)
}
}, Scope.BROADCAST)

View File

@ -0,0 +1,54 @@
import logging
from xmlrpc.client import Fault as SupervisorFault
from tfw.internals.networking import Scope
from .supervisor import ProcessManager, ProcessLogManager
LOG = logging.getLogger(__name__)
class ProcessHandler(ProcessManager, ProcessLogManager):
keys = ['processmanager']
"""
Event handler that can manage processes managed by supervisor.
This EventHandler accepts messages that have a data['command'] key specifying
a command to be executed.
Every message must contain a data['process_name'] field with the name of the
process to manage. This is the name specified in supervisor config files like so:
[program:someprogram]
Commands available: start, stop, restart, readlog
(the names are as self-documenting as it gets)
"""
def __init__(self, *, supervisor_uri, log_tail=0):
ProcessManager.__init__(self, supervisor_uri)
ProcessLogManager.__init__(self, supervisor_uri)
self.log_tail = log_tail
self.commands = {
'start': self.start_process,
'stop': self.stop_process,
'restart': self.restart_process
}
def handle_event(self, message, server_connector):
try:
data = message['data']
try:
self.commands[data['command']](data['process_name'])
except SupervisorFault as fault:
message['data']['error'] = fault.faultString
finally:
message['data']['stdout'] = self.read_stdout(
data['process_name'],
self.log_tail
)
message['data']['stderr'] = self.read_stderr(
data['process_name'],
self.log_tail
)
server_connector.send_message(message, scope=Scope.WEBSOCKET)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

View File

@ -0,0 +1,69 @@
import logging
from .log_inotify_observer import LogInotifyObserver
LOG = logging.getLogger(__name__)
class ProcessLogHandler:
keys = ['logmonitor']
"""
Monitors the output of a supervisor process (stdout, stderr) and
sends the results to the frontend.
Accepts messages that have a data['command'] key specifying
a command to be executed.
The API of each command is documented in their respective handler.
"""
def __init__(self, *, process_name, supervisor_uri, log_tail=0):
self.server_connector = None
self.process_name = process_name
self._supervisor_uri = supervisor_uri
self._initial_log_tail = log_tail
self._monitor = None
self.command_handlers = {
'process_name': self.handle_process_name,
'log_tail': self.handle_log_tail
}
def start(self):
self._monitor = LogInotifyObserver(
server_connector=self.server_connector,
supervisor_uri=self._supervisor_uri,
process_name=self.process_name,
log_tail=self._initial_log_tail
)
self._monitor.start()
def handle_event(self, message, _):
try:
data = message['data']
self.command_handlers[data['command']](data)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def handle_process_name(self, data):
"""
Changes the monitored process.
:param data: TFW message data containing 'value'
(name of the process to monitor)
"""
self._monitor.process_name = data['value']
def handle_log_tail(self, data):
"""
Sets tail length of the log the monitor will send
to the frontend (the monitor will send back the last
'value' characters of the log).
:param data: TFW message data containing 'value'
(new tail length)
"""
self._monitor.log_tail = data['value']
def cleanup(self):
self._monitor.stop()

View File

@ -0,0 +1,36 @@
from os import remove
from contextlib import suppress
import xmlrpc.client
from xmlrpc.client import Fault as SupervisorFault
class SupervisorBase:
def __init__(self, supervisor_uri):
self.supervisor = xmlrpc.client.ServerProxy(supervisor_uri).supervisor
class ProcessManager(SupervisorBase):
def stop_process(self, process_name):
with suppress(SupervisorFault):
self.supervisor.stopProcess(process_name)
def start_process(self, process_name):
self.supervisor.startProcess(process_name)
def restart_process(self, process_name):
self.stop_process(process_name)
self.start_process(process_name)
class ProcessLogManager(SupervisorBase):
def read_stdout(self, process_name, tail=0):
return self.supervisor.readProcessStdoutLog(process_name, -tail, 0)
def read_stderr(self, process_name, tail=0):
return self.supervisor.readProcessStderrLog(process_name, -tail, 0)
def clear_logs(self, process_name):
for logfile in ('stdout_logfile', 'stderr_logfile'):
with suppress(FileNotFoundError):
remove(self.supervisor.getProcessInfo(process_name)[logfile])
self.supervisor.clearProcessLogs(process_name)

View File

@ -0,0 +1 @@
from .snapshot_handler import SnapshotHandler

View File

@ -0,0 +1,86 @@
import logging
from os.path import join as joinpath
from os.path import basename
from os import makedirs
from datetime import datetime
from dateutil import parser as dateparser
from tfw.internals.networking import Scope
from .snapshot_provider import SnapshotProvider
LOG = logging.getLogger(__name__)
class SnapshotHandler:
keys = ['snapshot']
def __init__(self, *, directories, snapshots_dir, exclude_unix_patterns=None):
self._snapshots_dir = snapshots_dir
self.snapshot_providers = {}
self._exclude_unix_patterns = exclude_unix_patterns
self.init_snapshot_providers(directories)
self.command_handlers = {
'take_snapshot': self.handle_take_snapshot,
'restore_snapshot': self.handle_restore_snapshot,
'exclude': self.handle_exclude
}
def init_snapshot_providers(self, directories):
for index, directory in enumerate(directories):
git_dir = self.init_git_dir(index, directory)
self.snapshot_providers[directory] = SnapshotProvider(
directory,
git_dir,
self._exclude_unix_patterns
)
def init_git_dir(self, index, directory):
git_dir = joinpath(
self._snapshots_dir,
f'{basename(directory)}-{index}'
)
makedirs(git_dir, exist_ok=True)
return git_dir
def handle_event(self, message, server_connector):
try:
data = message['data']
message['data'] = self.command_handlers[data['command']](data)
server_connector.send_message(message, scope=Scope.WEBSOCKET)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def handle_take_snapshot(self, data):
LOG.debug('Taking snapshots of directories %s', self.snapshot_providers.keys())
for provider in self.snapshot_providers.values():
provider.take_snapshot()
return data
def handle_restore_snapshot(self, data):
date = dateparser.parse(
data.get(
'value',
datetime.now().isoformat()
)
)
LOG.debug(
'Restoring snapshots (@ %s) of directories %s',
date,
self.snapshot_providers.keys()
)
for provider in self.snapshot_providers.values():
provider.restore_snapshot(date)
return data
def handle_exclude(self, data):
exclude_unix_patterns = data['value']
if not isinstance(exclude_unix_patterns, list):
raise KeyError
for provider in self.snapshot_providers.values():
provider.exclude = exclude_unix_patterns
return data

View File

@ -0,0 +1,221 @@
import re
from subprocess import run, CalledProcessError, PIPE
from getpass import getuser
from os.path import isdir
from os.path import join as joinpath
from uuid import uuid4
from dateutil import parser as dateparser
class SnapshotProvider:
def __init__(self, directory, git_dir, exclude_unix_patterns=None):
self._classname = self.__class__.__name__
author = f'{getuser()} via TFW {self._classname}'
self.gitenv = {
'GIT_DIR': git_dir,
'GIT_WORK_TREE': directory,
'GIT_AUTHOR_NAME': author,
'GIT_AUTHOR_EMAIL': '',
'GIT_COMMITTER_NAME': author,
'GIT_COMMITTER_EMAIL': '',
'GIT_PAGER': 'cat'
}
self._init_repo()
self.__last_valid_branch = self._branch
if exclude_unix_patterns:
self.exclude = exclude_unix_patterns
def _init_repo(self):
self._check_environment()
if not self._repo_is_initialized:
self._run(('git', 'init'))
if self._number_of_commits == 0:
try:
self._snapshot()
except CalledProcessError:
raise EnvironmentError(f'{self._classname} cannot init on empty directories!')
self._check_head_not_detached()
def _check_environment(self):
if not isdir(self.gitenv['GIT_DIR']) or not isdir(self.gitenv['GIT_WORK_TREE']):
raise EnvironmentError(f'{self._classname}: "directory" and "git_dir" must exist!')
@property
def _repo_is_initialized(self):
return self._run(
('git', 'status'),
check=False
).returncode == 0
@property
def _number_of_commits(self):
return int(
self._get_stdout((
'git', 'rev-list',
'--all',
'--count'
))
)
def _snapshot(self):
self._run((
'git', 'add',
'-A'
))
try:
self._get_stdout((
'git', 'commit',
'-m', 'Snapshot'
))
except CalledProcessError as err:
if b'nothing to commit, working tree clean' not in err.output:
raise
def _check_head_not_detached(self):
if self._head_detached:
raise EnvironmentError(f'{self._classname} cannot init from detached HEAD state!')
@property
def _head_detached(self):
return self._branch == 'HEAD'
@property
def _branch(self):
return self._get_stdout((
'git', 'rev-parse',
'--abbrev-ref', 'HEAD'
))
def _get_stdout(self, *args, **kwargs):
kwargs['stdout'] = PIPE
kwargs['stderr'] = PIPE
stdout_bytes = self._run(*args, **kwargs).stdout
return stdout_bytes.decode().rstrip('\n')
def _run(self, *args, **kwargs):
if 'check' not in kwargs:
kwargs['check'] = True
if 'env' not in kwargs:
kwargs['env'] = self.gitenv
return run(*args, **kwargs)
@property
def exclude(self):
with open(self._exclude_path, 'r') as ofile:
return ofile.read()
@exclude.setter
def exclude(self, exclude_patterns):
with open(self._exclude_path, 'w') as ifile:
ifile.write('\n'.join(exclude_patterns))
@property
def _exclude_path(self):
return joinpath(
self.gitenv['GIT_DIR'],
'info',
'exclude'
)
def take_snapshot(self):
if self._head_detached:
self._checkout_new_branch_from_head()
self._snapshot()
def _checkout_new_branch_from_head(self):
branch_name = str(uuid4())
self._run((
'git', 'branch',
branch_name
))
self._checkout(branch_name)
def _checkout(self, what):
self._run((
'git', 'checkout',
what
))
def restore_snapshot(self, date):
commit = self._get_commit_from_timestamp(date)
branch = self._last_valid_branch
if commit == self._latest_commit_on_branch(branch):
commit = branch
self._checkout(commit)
def _get_commit_from_timestamp(self, date):
commit = self._get_stdout((
'git', 'rev-list',
'--date=iso',
'-n', '1',
f'--before="{date.isoformat()}"',
self._last_valid_branch
))
if not commit:
commit = self._get_oldest_parent_of_head()
return commit
def _get_oldest_parent_of_head(self):
return self._get_stdout((
'git',
'rev-list',
'--max-parents=0',
'HEAD'
))
@property
def _last_valid_branch(self):
if not self._head_detached:
self.__last_valid_branch = self._branch
return self.__last_valid_branch
def _latest_commit_on_branch(self, branch):
return self._get_stdout((
'git', 'log',
'-n', '1',
'--pretty=format:%H',
branch
))
@property
def all_timelines(self):
return self._branches
@property
def _branches(self):
git_branch_output = self._get_stdout(('git', 'branch'))
regex_pattern = re.compile(r'(?:[^\S\n]|[*])') # matches '*' and non-newline whitespace chars
return re.sub(regex_pattern, '', git_branch_output).splitlines()
@property
def timeline(self):
return self._last_valid_branch
@timeline.setter
def timeline(self, value):
self._checkout(value)
@property
def snapshots(self):
return self._pretty_log_branch()
def _pretty_log_branch(self):
git_log_output = self._get_stdout((
'git', 'log',
'--pretty=%H@%aI'
))
commits = []
for line in git_log_output.splitlines():
commit_hash, timestamp = line.split('@')
commits.append({
'hash': commit_hash,
'timestamp': dateparser.parse(timestamp)
})
return commits

View File

@ -0,0 +1,3 @@
from .terminal_handler import TerminalHandler
from .terminal_commands_handler import TerminalCommandsHandler
from .commands_equal import CommandsEqual

View File

@ -0,0 +1,107 @@
from shlex import split
from re import search
from tfw.internals.lazy import lazy_property
class CommandsEqual:
# pylint: disable=too-many-arguments
"""
This class is useful for comparing executed commands with
excepted commands (i.e. when triggering a state change when
the correct command is executed).
Note that in most cases you should test the changes
caused by the commands instead of just checking command history
(stuff can be done in countless ways and preparing for every
single case is impossible). This should only be used when
testing the changes would be very difficult, like when
explaining stuff with cli tools and such.
This class implicitly converts to bool, use it like
if CommandsEqual(...): ...
It tries detecting differing command parameter orders with similar
semantics and provides fuzzy logic options.
The rationale behind this is that a few false positives
are better than only accepting a single version of a command
(i.e. using ==).
"""
def __init__(
self, command_1, command_2,
fuzzyness=1, begin_similarly=True,
include_patterns=None, exclude_patterns=None
):
"""
:param command_1: Compared command 1
:param command_2: Compared command 2
:param fuzzyness: float between 0 and 1.
the percentage of arguments required to
match between commands to result in True.
i.e 1 means 100% - all arguments need to be
present in both commands, while 0.75
would mean 75% - in case of 4 arguments
1 could differ between the commands.
:param begin_similarly: bool, the first word of the commands
must match
:param include_patterns: list of regex patterns the commands
must include
:param exclude_patterns: list of regex patterns the commands
must exclude
"""
self.command_1 = split(command_1)
self.command_2 = split(command_2)
self.fuzzyness = fuzzyness
self.begin_similarly = begin_similarly
self.include_patterns = include_patterns
self.exclude_patterns = exclude_patterns
def __bool__(self):
if self.begin_similarly:
if not self.beginnings_are_equal:
return False
if self.include_patterns is not None:
if not self.commands_contain_include_patterns:
return False
if self.exclude_patterns is not None:
if not self.commands_contain_no_exclude_patterns:
return False
return self.similarity >= self.fuzzyness
@lazy_property
def beginnings_are_equal(self):
return self.command_1[0] == self.command_2[0]
@lazy_property
def commands_contain_include_patterns(self):
return all((
self.contains_regex_patterns(self.command_1, self.include_patterns),
self.contains_regex_patterns(self.command_2, self.include_patterns)
))
@lazy_property
def commands_contain_no_exclude_patterns(self):
return all((
not self.contains_regex_patterns(self.command_1, self.exclude_patterns),
not self.contains_regex_patterns(self.command_2, self.exclude_patterns)
))
@staticmethod
def contains_regex_patterns(command, regex_parts):
command = ' '.join(command)
for pattern in regex_parts:
if not search(pattern, command):
return False
return True
@lazy_property
def similarity(self):
parts_1 = set(self.command_1)
parts_2 = set(self.command_2)
difference = parts_1 - parts_2
deviance = len(difference) / len(max(parts_1, parts_2))
return 1 - deviance

View File

@ -0,0 +1,97 @@
from re import findall
from re import compile as compileregex
from abc import ABC, abstractmethod
from tfw.internals.inotify import InotifyObserver
class HistoryMonitor(ABC, InotifyObserver):
"""
Abstract class capable of monitoring and parsing a history file such as
bash HISTFILEs. Monitoring means detecting when the file was changed and
notifying subscribers about new content in the file.
This is useful for monitoring CLI sessions.
To specify a custom HistoryMonitor inherit from this class and override the
command pattern property and optionally the sanitize_command method.
See examples below.
"""
def __init__(self, uplink, histfile):
self.histfile = histfile
self.history = []
self._last_length = len(self.history)
self.uplink = uplink
super().__init__(self.histfile)
@property
@abstractmethod
def domain(self):
raise NotImplementedError()
def on_modified(self, event):
self._fetch_history()
if self._last_length < len(self.history):
for command in self.history[self._last_length:]:
self.send_message(command)
def _fetch_history(self):
self._last_length = len(self.history)
with open(self.histfile, 'r') as ifile:
pattern = compileregex(self.command_pattern)
data = ifile.read()
self.history = [
self.sanitize_command(command)
for command in findall(pattern, data)
]
@property
@abstractmethod
def command_pattern(self):
raise NotImplementedError
def sanitize_command(self, command):
# pylint: disable=no-self-use
return command
def send_message(self, command):
self.uplink.send_message({
'key': f'history.{self.domain}',
'value': command
})
class BashMonitor(HistoryMonitor):
"""
HistoryMonitor for monitoring bash CLI sessions.
This requires the following to be set in bash
(note that this is done automatically by TFW):
PROMPT_COMMAND="history -a"
shopt -s cmdhist
shopt -s histappend
unset HISTCONTROL
"""
@property
def domain(self):
return 'bash'
@property
def command_pattern(self):
return r'.+'
def sanitize_command(self, command):
return command.strip()
class GDBMonitor(HistoryMonitor):
"""
HistoryMonitor to monitor GDB sessions.
For this to work "set trace-commands on" must be set in GDB.
"""
@property
def domain(self):
return 'gdb'
@property
def command_pattern(self):
return r'(?<=\n)\+(.+)\n'

View File

@ -0,0 +1,44 @@
import logging
from tornado.web import Application
from terminado import TermSocket, SingleTermManager
LOG = logging.getLogger(__name__)
class TerminadoMiniServer:
def __init__(self, url, port, workdir, shellcmd):
self.port = port
self._term_manager = SingleTermManager(
shell_command=shellcmd,
term_settings={'cwd': workdir}
)
self.application = Application([(
url,
TerminadoMiniServer.ResetterTermSocket,
{'term_manager': self._term_manager}
)])
@property
def term_manager(self):
return self._term_manager
@property
def pty(self):
if self.term_manager.terminal is None:
self.term_manager.get_terminal()
return self.term_manager.terminal.ptyproc
class ResetterTermSocket(TermSocket): # pylint: disable=abstract-method
def check_origin(self, origin):
return True
def on_close(self):
self.term_manager.terminal = None
self.term_manager.get_terminal()
def listen(self):
self.application.listen(self.port)
def stop(self):
self.term_manager.shutdown()

View File

@ -0,0 +1,68 @@
import logging
from abc import ABC
from re import match
from shlex import split
LOG = logging.getLogger(__name__)
class TerminalCommands(ABC):
# pylint: disable=anomalous-backslash-in-string
"""
A class you can use to define hooks for terminal commands. This means that you can
have python code executed when the user enters a specific command to the terminal on
our frontend.
To receive events you need to subscribe TerminalCommand.callback to a HistoryMonitor
instance.
Inherit from this class and define methods which start with "command\_". When the user
executes the command specified after the underscore, your method will be invoked. All
such commands must expect the parameter \*args which will contain the arguments of the
command.
For example to define a method that runs when someone starts vim in the terminal
you have to define a method like: "def command_vim(self, \*args)"
You can also use this class to create new commands similarly.
"""
def __init__(self, bashrc):
self._command_method_regex = r'^command_(.+)$'
self.command_implemetations = self._build_command_to_implementation_dict()
if bashrc is not None:
self._setup_bashrc_aliases(bashrc)
def _build_command_to_implementation_dict(self):
return {
self._parse_command_name(fun): getattr(self, fun)
for fun in dir(self)
if callable(getattr(self, fun))
and self._is_command_implementation(fun)
}
def _setup_bashrc_aliases(self, bashrc):
with open(bashrc, 'a') as ofile:
alias_template = 'type {0} &> /dev/null || alias {0}="{0} &> /dev/null"\n'
for command in self.command_implemetations.keys():
ofile.write(alias_template.format(command))
def _is_command_implementation(self, method_name):
return bool(self._match_command_regex(method_name))
def _parse_command_name(self, method_name):
try:
return self._match_command_regex(method_name).groups()[0]
except AttributeError:
return ''
def _match_command_regex(self, string):
return match(self._command_method_regex, string)
def callback(self, command):
parts = split(command)
command = parts[0]
if command in self.command_implemetations.keys():
try:
self.command_implemetations[command](*parts[1:])
except Exception: # pylint: disable=broad-except
LOG.exception('Command "%s" failed:', command)

View File

@ -0,0 +1,9 @@
from .terminal_commands import TerminalCommands
class TerminalCommandsHandler(TerminalCommands):
keys = ['history.bash']
def handle_event(self, message, _):
command = message['value']
self.callback(command)

View File

@ -0,0 +1,86 @@
import logging
from .history_monitor import BashMonitor
from .terminado_mini_server import TerminadoMiniServer
LOG = logging.getLogger(__name__)
class TerminalHandler:
keys = ['shell']
"""
Event handler responsible for managing terminal sessions for frontend xterm
sessions to connect to. You need to instanciate this in order for frontend
terminals to work.
This EventHandler accepts messages that have a data['command'] key specifying
a command to be executed.
The API of each command is documented in their respective handler.
"""
def __init__(self, *, port, user, workind_directory, histfile):
"""
:param key: key this EventHandler listens to
:param monitor: tfw.components.HistoryMonitor instance to read command history from
"""
self.server_connector = None
self._histfile = histfile
self._historymonitor = None
bash_as_user_cmd = ['sudo', '-u', user, 'bash']
self.terminado_server = TerminadoMiniServer(
'/terminal',
port,
workind_directory,
bash_as_user_cmd
)
self.commands = {
'write': self.write,
'read': self.read
}
self.terminado_server.listen()
def start(self):
self._historymonitor = BashMonitor(self.server_connector, self._histfile)
self._historymonitor.start()
@property
def historymonitor(self):
return self._historymonitor
def handle_event(self, message, _):
try:
data = message['data']
message['data'] = self.commands[data['command']](data)
except KeyError:
LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
def write(self, data):
"""
Writes a string to the terminal session (on the pty level).
Useful for pre-typing and executing commands for the user.
:param data: TFW message data containing 'value'
(command to be written to the pty)
"""
self.terminado_server.pty.write(data['value'])
return data
def read(self, data):
"""
Reads the history of commands executed.
:param data: TFW message data containing 'count'
(the number of history elements to return)
:return dict: message with list of commands in data['history']
"""
data['count'] = int(data.get('count', 1))
if self.historymonitor:
data['history'] = self.historymonitor.history[-data['count']:]
return data
def cleanup(self):
self.terminado_server.stop()
self.historymonitor.stop()

1
tfw/config/__init__.py Normal file
View File

@ -0,0 +1 @@
from .envvars import TFWENV, TAOENV

5
tfw/config/envvars.py Normal file
View File

@ -0,0 +1,5 @@
from .lazy_environment import LazyEnvironment
TFWENV = LazyEnvironment('TFW_', 'tfwenvtuple').environment
TAOENV = LazyEnvironment('AVATAO_', 'taoenvtuple').environment

View File

@ -0,0 +1,22 @@
from collections import namedtuple
from os import environ
from tfw.internals.lazy import lazy_property
class LazyEnvironment:
def __init__(self, prefix, tuple_name):
self._prefix = prefix
self._tuple_name = tuple_name
@lazy_property
def environment(self):
return self.prefixed_envvars_to_namedtuple()
def prefixed_envvars_to_namedtuple(self):
envvars = {
envvar.replace(self._prefix, '', 1): environ.get(envvar)
for envvar in environ.keys()
if envvar.startswith(self._prefix)
}
return namedtuple(self._tuple_name, envvars)(**envvars)

2
tfw/event_handlers.py Normal file
View File

@ -0,0 +1,2 @@
# pylint: disable=unused-import
from tfw.internals.event_handling import EventHandler, FSMAwareEventHandler

3
tfw/fsm/__init__.py Normal file
View File

@ -0,0 +1,3 @@
from .fsm_base import FSMBase
from .linear_fsm import LinearFSM
from .yaml_fsm import YamlFSM

84
tfw/fsm/fsm_base.py Normal file
View File

@ -0,0 +1,84 @@
import logging
from collections import defaultdict
from datetime import datetime
from transitions import Machine, MachineError
from tfw.internals.callback_mixin import CallbackMixin
LOG = logging.getLogger(__name__)
class FSMBase(Machine, CallbackMixin):
"""
A general FSM base class you can inherit from to track user progress.
See linear_fsm.py for an example use-case.
TFW uses the transitions library for state machines, please refer to their
documentation for more information on creating your own machines:
https://github.com/pytransitions/transitions
"""
states, transitions = [], []
def __init__(self, initial=None, accepted_states=None):
"""
:param initial: which state to begin with, defaults to the last one
:param accepted_states: list of states in which the challenge should be
considered successfully completed
"""
self.accepted_states = accepted_states or [self.states[-1].name]
self.trigger_predicates = defaultdict(list)
self.event_log = []
Machine.__init__(
self,
states=self.states,
transitions=self.transitions,
initial=initial or self.states[0],
send_event=True,
ignore_invalid_triggers=True,
after_state_change='execute_callbacks'
)
def execute_callbacks(self, event_data):
self._execute_callbacks(event_data.kwargs)
def is_solved(self):
return self.state in self.accepted_states # pylint: disable=no-member
def subscribe_predicate(self, trigger, *predicates):
self.trigger_predicates[trigger].extend(predicates)
def unsubscribe_predicate(self, trigger, *predicates):
self.trigger_predicates[trigger] = [
predicate
for predicate in self.trigger_predicates[trigger]
not in predicates
]
def step(self, trigger):
predicate_results = (
predicate()
for predicate in self.trigger_predicates[trigger]
)
if all(predicate_results):
try:
from_state = self.state
self.trigger(trigger)
self.update_event_log(from_state, trigger)
return True
except (AttributeError, MachineError):
LOG.debug('FSM failed to execute nonexistent trigger: "%s"', trigger)
return False
def update_event_log(self, from_state, trigger):
self.event_log.append({
'from_state': from_state,
'to_state': self.state,
'trigger': trigger,
'timestamp': datetime.utcnow()
})
@property
def in_accepted_state(self):
return self.state in self.accepted_states

32
tfw/fsm/linear_fsm.py Normal file
View File

@ -0,0 +1,32 @@
from transitions import State
from .fsm_base import FSMBase
class LinearFSM(FSMBase):
# pylint: disable=anomalous-backslash-in-string
"""
This is a state machine for challenges with linear progression, consisting of
a number of steps specified in the constructor. It automatically sets up 2
actions (triggers) between states as such:
(0) -- step_1 --> (1) -- step_2 --> (2) -- step_3 --> (3) ...
(0) -- step_next --> (1) -- step_next --> (2) -- step_next --> (3) ...
"""
def __init__(self, number_of_steps):
"""
:param number_of_steps: how many states this FSM should have
"""
self.states = [State(name=str(index)) for index in range(number_of_steps)]
self.transitions = []
for state in self.states[:-1]:
self.transitions.append({
'trigger': f'step_{int(state.name)+1}',
'source': state.name,
'dest': str(int(state.name)+1)
})
self.transitions.append({
'trigger': 'step_next',
'source': state.name,
'dest': str(int(state.name)+1)
})
super(LinearFSM, self).__init__()

105
tfw/fsm/yaml_fsm.py Normal file
View File

@ -0,0 +1,105 @@
from subprocess import Popen, run
from functools import partial, singledispatch
from contextlib import suppress
import yaml
import jinja2
from transitions import State
from .fsm_base import FSMBase
class YamlFSM(FSMBase):
"""
This is a state machine capable of building itself from a YAML config file.
"""
def __init__(self, config_file, jinja2_variables=None):
"""
:param config_file: path of the YAML file
:param jinja2_variables: dict containing jinja2 variables
or str with filename of YAML file to
parse and use as dict.
jinja2 support is disabled if this is None
"""
self.config = ConfigParser(config_file, jinja2_variables).config
self.setup_states()
super().__init__() # FSMBase.__init__() requires states
self.setup_transitions()
def setup_states(self):
self.for_config_states_and_transitions_do(self.wrap_callbacks_with_subprocess_call)
self.states = [State(**state) for state in self.config['states']]
def setup_transitions(self):
self.for_config_states_and_transitions_do(self.subscribe_and_remove_predicates)
for transition in self.config['transitions']:
self.add_transition(**transition)
def for_config_states_and_transitions_do(self, what):
for array in ('states', 'transitions'):
for json_obj in self.config[array]:
what(json_obj)
@staticmethod
def wrap_callbacks_with_subprocess_call(json_obj):
topatch = ('on_enter', 'on_exit', 'prepare', 'before', 'after')
for key in json_obj:
if key in topatch:
json_obj[key] = partial(run_command_async, json_obj[key])
def subscribe_and_remove_predicates(self, json_obj):
if 'predicates' in json_obj:
for predicate in json_obj['predicates']:
self.subscribe_predicate(
json_obj['trigger'],
partial(
command_statuscode_is_zero,
predicate
)
)
with suppress(KeyError):
json_obj.pop('predicates')
def run_command_async(command, _):
Popen(command, shell=True)
def command_statuscode_is_zero(command):
return run(command, shell=True).returncode == 0
class ConfigParser:
def __init__(self, config_file, jinja2_variables):
self.read_variables = singledispatch(self._read_variables)
self.read_variables.register(dict, self._read_variables_dict)
self.read_variables.register(str, self._read_variables_str)
self.config = self.parse_config(config_file, jinja2_variables)
def parse_config(self, config_file, jinja2_variables):
config_string = self.read_file(config_file)
if jinja2_variables is not None:
variables = self.read_variables(jinja2_variables)
template = jinja2.Environment(loader=jinja2.BaseLoader).from_string(config_string)
config_string = template.render(**variables)
return yaml.safe_load(config_string)
@staticmethod
def read_file(filename):
with open(filename, 'r') as ifile:
return ifile.read()
@staticmethod
def _read_variables(variables):
raise TypeError(f'Invalid variables type {type(variables)}')
@staticmethod
def _read_variables_str(variables):
with open(variables, 'r') as ifile:
return yaml.safe_load(ifile)
@staticmethod
def _read_variables_dict(variables):
return variables

View File

View File

@ -0,0 +1,35 @@
from functools import partial
from .lazy import lazy_property
class CallbackMixin:
@lazy_property
def _callbacks(self):
# pylint: disable=no-self-use
return []
def subscribe_callback(self, callback, *args, **kwargs):
"""
Subscribe a callable to invoke once an event is triggered.
:param callback: callable to be executed on events
:param args: arguments passed to callable
:param kwargs: kwargs passed to callable
"""
fun = partial(callback, *args, **kwargs)
self._callbacks.append(fun)
def subscribe_callbacks(self, *callbacks):
"""
Subscribe a list of callbacks to incoke once an event is triggered.
:param callbacks: callbacks to be subscribed
"""
for callback in callbacks:
self.subscribe_callback(callback)
def unsubscribe_callback(self, callback):
self._callbacks.remove(callback)
def _execute_callbacks(self, *args, **kwargs):
for callback in self._callbacks:
callback(*args, **kwargs)

107
tfw/internals/crypto.py Normal file
View File

@ -0,0 +1,107 @@
from functools import wraps
from base64 import b64encode, b64decode
from copy import deepcopy
from hashlib import md5
from os import urandom, chmod
from os.path import exists
from stat import S_IRUSR, S_IWUSR, S_IXUSR
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
from cryptography.exceptions import InvalidSignature
from tfw.internals.networking import message_bytes
from tfw.internals.lazy import lazy_property
from tfw.config import TFWENV
def message_checksum(message):
return md5(message_bytes(message)).hexdigest()
def sign_message(key, message):
message.pop('scope', None)
message.pop('signature', None)
signature = message_signature(key, message)
message['signature'] = b64encode(signature).decode()
def message_signature(key, message):
return HMAC(key, message_bytes(message)).signature
def verify_message(key, message):
message.pop('scope', None)
message = deepcopy(message)
try:
signature_b64 = message.pop('signature')
signature = b64decode(signature_b64)
actual_signature = message_signature(key, message)
return signature == actual_signature
except KeyError:
return False
class KeyManager:
def __init__(self):
self.keyfile = TFWENV.AUTH_KEY
if not exists(self.keyfile):
self._init_auth_key()
@lazy_property
def auth_key(self):
with open(self.keyfile, 'rb') as ifile:
return ifile.read()
def _init_auth_key(self):
key = self.generate_key()
with open(self.keyfile, 'wb') as ofile:
ofile.write(key)
self._chmod_700_keyfile()
return key
@staticmethod
def generate_key():
return urandom(32)
def _chmod_700_keyfile(self):
chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
class HMAC:
def __init__(self, key, message):
self.key = key
self.message = message
self._hmac = _HMAC(
key=key,
algorithm=SHA256(),
backend=default_backend()
)
def _reload_if_finalized(f):
# pylint: disable=no-self-argument,not-callable
@wraps(f)
def wrapped(instance, *args, **kwargs):
if getattr(instance, '_finalized', False):
instance.__init__(instance.key, instance.message)
ret_val = f(instance, *args, **kwargs)
setattr(instance, '_finalized', True)
return ret_val
return wrapped
@property
@_reload_if_finalized
def signature(self):
self._hmac.update(self.message)
signature = self._hmac.finalize()
return signature
@_reload_if_finalized
def verify(self, signature):
self._hmac.update(self.message)
try:
self._hmac.verify(signature)
return True
except InvalidSignature:
return False

View File

@ -0,0 +1,3 @@
from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler
from .fsm_aware_event_handler import FSMAwareEventHandler

View File

@ -0,0 +1,27 @@
class EventHandler:
_instances = set()
def __init__(self, server_connector):
type(self)._instances.add(self)
self.server_connector = server_connector
def start(self):
self.server_connector.register_callback(self._event_callback)
def _event_callback(self, message):
self.handle_event(message, self.server_connector)
def handle_event(self, message, server_connector):
raise NotImplementedError()
@classmethod
def stop_all_instances(cls):
for instance in cls._instances:
instance.stop()
def stop(self):
self.server_connector.close()
self.cleanup()
def cleanup(self):
pass

View File

@ -0,0 +1,68 @@
from contextlib import suppress
from .event_handler import EventHandler
class EventHandlerFactoryBase:
def build(self, handler_stub, *, keys=None, event_handler_type=EventHandler):
builder = EventHandlerBuilder(handler_stub, keys, event_handler_type)
server_connector = self._build_server_connector()
event_handler = builder.build(server_connector)
handler_stub.server_connector = server_connector
with suppress(AttributeError):
handler_stub.start()
event_handler.start()
return event_handler
def _build_server_connector(self):
raise NotImplementedError()
class EventHandlerBuilder:
def __init__(self, event_handler, supplied_keys, event_handler_type):
self._analyzer = HandlerStubAnalyzer(event_handler, supplied_keys)
self._event_handler_type = event_handler_type
def build(self, server_connector):
event_handler = self._event_handler_type(server_connector)
server_connector.subscribe(*self._try_get_keys(event_handler))
event_handler.handle_event = self._analyzer.handle_event
with suppress(AttributeError):
event_handler.cleanup = self._analyzer.cleanup
return event_handler
def _try_get_keys(self, event_handler):
try:
return self._analyzer.keys
except ValueError:
with suppress(AttributeError):
return event_handler.keys
raise
class HandlerStubAnalyzer:
def __init__(self, event_handler, supplied_keys):
self._event_handler = event_handler
self._supplied_keys = supplied_keys
@property
def keys(self):
if self._supplied_keys is None:
try:
return self._event_handler.keys
except AttributeError:
raise ValueError('No keys supplied!')
return self._supplied_keys
@property
def handle_event(self):
try:
return self._event_handler.handle_event
except AttributeError:
if callable(self._event_handler):
return self._event_handler
raise ValueError('Object must implement handle_event or be a callable!')
@property
def cleanup(self):
return self._event_handler.cleanup

View File

@ -0,0 +1,37 @@
import logging
from tfw.internals.crypto import KeyManager, verify_message
LOG = logging.getLogger(__name__)
class FSMAware:
keys = ['fsm_update']
"""
Base class for stuff that has to be aware of the framework FSM.
This is done by processing 'fsm_update' messages.
"""
def __init__(self):
self.fsm_state = None
self.fsm_in_accepted_state = False
self.fsm_event_log = []
self._auth_key = KeyManager().auth_key
def process_message(self, message):
if message['key'] == 'fsm_update':
if verify_message(self._auth_key, message):
self._handle_fsm_update(message)
def _handle_fsm_update(self, message):
try:
new_state = message['current_state']
if self.fsm_state != new_state:
self.handle_fsm_step(message)
self.fsm_state = new_state
self.fsm_in_accepted_state = message['in_accepted_state']
self.fsm_event_log.append(message)
except KeyError:
LOG.error('Invalid fsm_update message received!')
def handle_fsm_step(self, message):
pass

View File

@ -0,0 +1,19 @@
from .event_handler import EventHandler
from .fsm_aware import FSMAware
class FSMAwareEventHandler(EventHandler, FSMAware):
# pylint: disable=abstract-method
"""
Abstract base class for EventHandlers which automatically
keep track of the state of the TFW FSM.
"""
def __init__(self, server_connector):
EventHandler.__init__(self, server_connector)
FSMAware.__init__(self)
def _event_callback(self, message):
self.process_message(message)
def handle_fsm_step(self, message):
self.handle_event(message, self.server_connector)

View File

@ -0,0 +1,190 @@
# pylint: disable=redefined-outer-name,attribute-defined-outside-init
from secrets import token_urlsafe
from random import randint
import pytest
from .event_handler_factory_base import EventHandlerFactoryBase
from .event_handler import EventHandler
class MockEventHandlerFactory(EventHandlerFactoryBase):
def _build_server_connector(self):
return MockServerConnector()
class MockServerConnector:
def __init__(self):
self.keys = []
self._on_message = None
def simulate_message(self, message):
self._on_message(message)
def register_callback(self, callback):
self._on_message = callback
def subscribe(self, *keys):
self.keys.extend(keys)
def unsubscribe(self, *keys):
for key in keys:
self.keys.remove(key)
def send_message(self, message, scope=None):
pass
def close(self):
pass
class MockEventHandlerStub:
def __init__(self):
self.server_connector = None
self.last_message = None
self.cleaned_up = False
self.started = False
def start(self):
self.started = True
def cleanup(self):
self.cleaned_up = True
class MockEventHandler(MockEventHandlerStub):
# pylint: disable=unused-argument
def handle_event(self, message, server_connector):
self.last_message = message
class MockCallable(MockEventHandlerStub):
def __call__(self, message, server_connector):
self.last_message = message
@pytest.fixture
def test_msg():
yield token_urlsafe(randint(16, 64))
@pytest.fixture
def test_keys():
yield [
token_urlsafe(randint(2, 8))
for _ in range(randint(16, 32))
]
def test_build_from_object(test_keys, test_msg):
mock_eh = MockEventHandlerStub()
def handle_event(message, server_connector):
raise RuntimeError(message, server_connector.keys)
mock_eh.handle_event = handle_event
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
assert mock_eh.started
assert mock_eh.server_connector is eh.server_connector
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
assert not mock_eh.cleaned_up
eh.stop()
assert mock_eh.cleaned_up
def test_build_from_object_with_keys(test_keys, test_msg):
mock_eh = MockEventHandler()
mock_eh.keys = test_keys
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh)
assert mock_eh.server_connector.keys == test_keys
assert eh.server_connector is mock_eh.server_connector
assert mock_eh.started
assert not mock_eh.last_message
eh.server_connector.simulate_message(test_msg)
assert mock_eh.last_message == test_msg
assert not mock_eh.cleaned_up
EventHandler.stop_all_instances()
assert mock_eh.cleaned_up
def test_build_from_simple_object(test_keys, test_msg):
class SimpleMockEventHandler:
# pylint: disable=no-self-use
def handle_event(self, message, server_connector):
raise RuntimeError(message, server_connector)
mock_eh = SimpleMockEventHandler()
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
def test_build_from_callable(test_keys, test_msg):
mock_eh = MockCallable()
assert not mock_eh.started
eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys)
assert mock_eh.started
assert mock_eh.server_connector is eh.server_connector
assert eh.server_connector.keys == test_keys
assert not mock_eh.last_message
eh.server_connector.simulate_message(test_msg)
assert mock_eh.last_message == test_msg
assert not mock_eh.cleaned_up
eh.stop()
assert mock_eh.cleaned_up
def test_build_from_function(test_keys, test_msg):
def some_function(message, server_connector):
raise RuntimeError(message, server_connector.keys)
eh = MockEventHandlerFactory().build(some_function, keys=test_keys)
assert eh.server_connector.keys == test_keys
with pytest.raises(RuntimeError) as err:
eh.server_connector.simulate_message(test_msg)
msg, keys = err.args
assert msg == test_msg
assert keys == test_keys
def test_build_from_lambda(test_keys, test_msg):
def assert_messages_equal(msg):
assert msg == test_msg
fun = lambda msg, sc: assert_messages_equal(msg)
eh = MockEventHandlerFactory().build(fun, keys=test_keys)
eh.server_connector.simulate_message(test_msg)
def test_build_raises_if_no_key(test_keys):
eh = MockEventHandler()
with pytest.raises(ValueError):
MockEventHandlerFactory().build(eh)
def handle_event(*_):
pass
with pytest.raises(ValueError):
MockEventHandlerFactory().build(handle_event)
with pytest.raises(ValueError):
MockEventHandlerFactory().build(lambda msg, sc: None)
WithKeysEventHandler = EventHandler
WithKeysEventHandler.keys = test_keys
MockEventHandlerFactory().build(eh, event_handler_type=WithKeysEventHandler)
eh.keys = test_keys
MockEventHandlerFactory().build(eh)

View File

@ -0,0 +1,6 @@
from .inotify import InotifyObserver
from .inotify import (
InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
InotifyDirMovedEvent, InotifyDirDeletedEvent
)

View File

@ -0,0 +1,189 @@
# pylint: disable=too-few-public-methods
from typing import Iterable
from time import time
from os.path import abspath, dirname, isdir
from watchdog.observers import Observer
from watchdog.events import FileSystemMovedEvent, PatternMatchingEventHandler
from watchdog.events import (
FileCreatedEvent, FileModifiedEvent, FileMovedEvent, FileDeletedEvent,
DirCreatedEvent, DirModifiedEvent, DirMovedEvent, DirDeletedEvent
)
class InotifyEvent:
def __init__(self, src_path):
self.date = time()
self.src_path = src_path
def __str__(self):
return self.__repr__()
def __repr__(self):
return f'{self.__class__.__name__}({self.src_path})'
class InotifyMovedEvent(InotifyEvent):
def __init__(self, src_path, dest_path):
self.dest_path = dest_path
super().__init__(src_path)
def __repr__(self):
return f'{self.__class__.__name__}({self.src_path}, {self.dest_path})'
class InotifyFileCreatedEvent(InotifyEvent):
pass
class InotifyFileModifiedEvent(InotifyEvent):
pass
class InotifyFileMovedEvent(InotifyMovedEvent):
pass
class InotifyFileDeletedEvent(InotifyEvent):
pass
class InotifyDirCreatedEvent(InotifyEvent):
pass
class InotifyDirModifiedEvent(InotifyEvent):
pass
class InotifyDirMovedEvent(InotifyMovedEvent):
pass
class InotifyDirDeletedEvent(InotifyEvent):
pass
class InotifyObserver:
def __init__(self, path, patterns=None, exclude=None, recursive=False):
self._files = []
self._paths = path
self._patterns = patterns or []
self._exclude = exclude
self._recursive = recursive
self._observer = Observer()
self._reset()
def _reset(self):
if isinstance(self._paths, str):
self._paths = [self._paths]
if isinstance(self._paths, Iterable):
self._extract_files_from_paths()
else:
raise ValueError('Expected one or more string paths.')
patterns = self._files+self.patterns
handler = PatternMatchingEventHandler(patterns if patterns else None, self.exclude)
handler.on_any_event = self._dispatch_event
self._observer.unschedule_all()
for path in self.paths:
self._observer.schedule(handler, path, self._recursive)
def _extract_files_from_paths(self):
files, paths = [], []
for path in self._paths:
path = abspath(path)
if isdir(path):
paths.append(path)
else:
paths.append(dirname(path))
files.append(path)
self._files, self._paths = files, paths
@property
def paths(self):
return self._paths
@paths.setter
def paths(self, paths):
self._paths = paths
self._reset()
@property
def patterns(self):
return self._patterns
@patterns.setter
def patterns(self, patterns):
self._patterns = patterns or []
self._reset()
@property
def exclude(self):
return self._exclude
@exclude.setter
def exclude(self, exclude):
self._exclude = exclude
self._reset()
def start(self):
self._observer.start()
def stop(self):
self._observer.stop()
self._observer.join()
def _dispatch_event(self, event):
event_to_action = {
InotifyFileCreatedEvent : self.on_created,
InotifyFileModifiedEvent : self.on_modified,
InotifyFileMovedEvent : self.on_moved,
InotifyFileDeletedEvent : self.on_deleted,
InotifyDirCreatedEvent : self.on_created,
InotifyDirModifiedEvent : self.on_modified,
InotifyDirMovedEvent : self.on_moved,
InotifyDirDeletedEvent : self.on_deleted
}
event = self._transform_event(event)
self.on_any_event(event)
event_to_action[type(event)](event)
@staticmethod
def _transform_event(event):
watchdog_to_inotify = {
FileCreatedEvent : InotifyFileCreatedEvent,
FileModifiedEvent : InotifyFileModifiedEvent,
FileMovedEvent : InotifyFileMovedEvent,
FileDeletedEvent : InotifyFileDeletedEvent,
DirCreatedEvent : InotifyDirCreatedEvent,
DirModifiedEvent : InotifyDirModifiedEvent,
DirMovedEvent : InotifyDirMovedEvent,
DirDeletedEvent : InotifyDirDeletedEvent
}
try:
cls = watchdog_to_inotify[type(event)]
except KeyError:
raise NameError('Watchdog API returned an unknown event.')
if isinstance(event, FileSystemMovedEvent):
return cls(event.src_path, event.dest_path)
return cls(event.src_path)
def on_any_event(self, event):
pass
def on_created(self, event):
pass
def on_modified(self, event):
pass
def on_moved(self, event):
pass
def on_deleted(self, event):
pass

View File

@ -0,0 +1,179 @@
# pylint: disable=redefined-outer-name
from queue import Empty, Queue
from secrets import token_urlsafe
from pathlib import Path
from shutil import rmtree
from os.path import join
from os import mkdir, remove, rename
from tempfile import TemporaryDirectory
from contextlib import suppress
import watchdog
import pytest
from .inotify import InotifyObserver
from .inotify import (
InotifyFileCreatedEvent, InotifyFileModifiedEvent, InotifyFileMovedEvent,
InotifyFileDeletedEvent, InotifyDirCreatedEvent, InotifyDirModifiedEvent,
InotifyDirMovedEvent, InotifyDirDeletedEvent
)
with suppress(AttributeError):
watchdog.observers.inotify_buffer.InotifyBuffer.delay = 0
class InotifyContext:
def __init__(self, workdir, subdir, subfile, observer):
self.missing_events = 0
self.workdir = workdir
self.subdir = subdir
self.subfile = subfile
self.observer = observer
self.event_to_queue = {
InotifyFileCreatedEvent : self.observer.create_queue,
InotifyFileModifiedEvent : self.observer.modify_queue,
InotifyFileMovedEvent : self.observer.move_queue,
InotifyFileDeletedEvent : self.observer.delete_queue,
InotifyDirCreatedEvent : self.observer.create_queue,
InotifyDirModifiedEvent : self.observer.modify_queue,
InotifyDirMovedEvent : self.observer.move_queue,
InotifyDirDeletedEvent : self.observer.delete_queue
}
def create_random_file(self, dirname, extension):
filename = self.join(f'{dirname}/{generate_name()}{extension}')
Path(filename).touch()
return filename
def create_random_folder(self, basepath):
dirname = self.join(f'{basepath}/{generate_name()}')
mkdir(dirname)
return dirname
def join(self, path):
return join(self.workdir, path)
def check_event(self, event_type, path):
self.missing_events += 1
event = self.event_to_queue[event_type].get(timeout=0.1)
assert isinstance(event, event_type)
assert event.src_path == path
return event
def check_empty(self, event_type):
with pytest.raises(Empty):
self.event_to_queue[event_type].get(timeout=0.1)
def check_any(self):
attrs = self.observer.__dict__.values()
total = sum([q.qsize() for q in attrs if isinstance(q, Queue)])
return total+self.missing_events == len(self.observer.any_list)
class InotifyTestObserver(InotifyObserver):
def __init__(self, paths, patterns=None, exclude=None, recursive=False):
self.any_list = []
self.create_queue, self.modify_queue, self.move_queue, self.delete_queue = [Queue() for _ in range(4)]
super().__init__(paths, patterns, exclude, recursive)
def on_any_event(self, event):
self.any_list.append(event)
def on_created(self, event):
self.create_queue.put(event)
def on_modified(self, event):
self.modify_queue.put(event)
def on_moved(self, event):
self.move_queue.put(event)
def on_deleted(self, event):
self.delete_queue.put(event)
def generate_name():
return token_urlsafe(16)
@pytest.fixture()
def context():
with TemporaryDirectory() as workdir:
subdir = join(workdir, generate_name())
subfile = join(subdir, generate_name()+'.txt')
mkdir(subdir)
Path(subfile).touch()
monitor = InotifyTestObserver(workdir, recursive=True)
monitor.start()
yield InotifyContext(workdir, subdir, subfile, monitor)
def test_create(context):
newfile = context.create_random_file(context.workdir, '.txt')
context.check_event(InotifyFileCreatedEvent, newfile)
newdir = context.create_random_folder(context.workdir)
context.check_event(InotifyDirCreatedEvent, newdir)
assert context.check_any()
def test_modify(context):
with open(context.subfile, 'wb', buffering=0) as ofile:
ofile.write(b'text')
context.check_event(InotifyFileModifiedEvent, context.subfile)
while True:
try:
context.observer.modify_queue.get(timeout=0.1)
context.missing_events += 1
except Empty:
break
rename(context.subfile, context.subfile+'_new')
context.check_event(InotifyDirModifiedEvent, context.subdir)
assert context.check_any()
def test_move(context):
rename(context.subdir, context.subdir+'_new')
context.check_event(InotifyDirMovedEvent, context.subdir)
context.check_event(InotifyFileMovedEvent, context.subfile)
assert context.check_any()
def test_delete(context):
rmtree(context.subdir)
context.check_event(InotifyFileDeletedEvent, context.subfile)
context.check_event(InotifyDirDeletedEvent, context.subdir)
assert context.check_any()
def test_paths(context):
context.observer.paths = context.subdir
newdir = context.create_random_folder(context.workdir)
newfile = context.create_random_file(context.subdir, '.txt')
context.check_event(InotifyDirModifiedEvent, context.subdir)
context.check_event(InotifyFileCreatedEvent, newfile)
context.observer.paths = [newdir, newfile]
remove(newfile)
context.check_event(InotifyFileDeletedEvent, newfile)
assert context.check_any()
context.observer.paths = context.workdir
def test_patterns(context):
context.observer.patterns = ['*.txt']
context.create_random_file(context.subdir, '.bin')
newfile = context.create_random_file(context.subdir, '.txt')
context.check_event(InotifyFileCreatedEvent, newfile)
context.check_empty(InotifyFileCreatedEvent)
assert context.check_any()
context.observer.patterns = None
def test_exclude(context):
context.observer.exclude = ['*.txt']
context.create_random_file(context.subdir, '.txt')
newfile = context.create_random_file(context.subdir, '.bin')
context.check_event(InotifyFileCreatedEvent, newfile)
context.check_empty(InotifyFileCreatedEvent)
assert context.check_any()
context.observer.exclude = None
def test_stress(context):
newfile = []
for i in range(1024):
newfile.append(context.create_random_file(context.subdir, '.txt'))
for i in range(1024):
context.check_event(InotifyFileCreatedEvent, newfile[i])
assert context.check_any()

27
tfw/internals/lazy.py Normal file
View File

@ -0,0 +1,27 @@
from functools import update_wrapper, wraps
class lazy_property:
"""
Decorator that replaces a function with the value
it calculates on the first call.
"""
def __init__(self, func):
self.func = func
update_wrapper(self, func)
def __get__(self, instance, owner):
if instance is None:
return self # avoids potential __new__ TypeError
value = self.func(instance)
setattr(instance, self.func.__name__, value)
return value
def lazy_factory(fun):
class wrapper:
@wraps(fun)
@lazy_property
def instance(self): # pylint: disable=no-self-use
return fun()
return wrapper()

View File

@ -0,0 +1,4 @@
from .serialization import serialize_tfw_msg, deserialize_tfw_msg, with_deserialize_tfw_msg, message_bytes
from .server_connector import ServerUplinkConnector, ServerDownlinkConnector, ServerConnector
from .event_handler_connector import EventHandlerConnector
from .scope import Scope

View File

@ -0,0 +1,48 @@
import logging
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
LOG = logging.getLogger(__name__)
class EventHandlerDownlinkConnector:
def __init__(self, bind_addr):
self._zmq_pull_socket = zmq.Context.instance().socket(zmq.PULL)
self._zmq_pull_socket.setsockopt(zmq.RCVHWM, 0)
self._zmq_pull_stream = ZMQStream(self._zmq_pull_socket)
self._zmq_pull_socket.bind(bind_addr)
LOG.debug('Pull socket bound to %s', bind_addr)
def register_callback(self, callback):
callback = with_deserialize_tfw_msg(callback)
self._zmq_pull_stream.on_recv(callback)
def close(self):
self._zmq_pull_stream.close()
class EventHandlerUplinkConnector:
def __init__(self, bind_addr):
self._zmq_pub_socket = zmq.Context.instance().socket(zmq.PUB)
self._zmq_pub_socket.setsockopt(zmq.SNDHWM, 0)
self._zmq_pub_socket.bind(bind_addr)
LOG.debug('Pub socket bound to %s', bind_addr)
def send_message(self, message: dict):
self._zmq_pub_socket.send_multipart(serialize_tfw_msg(message))
def close(self):
self._zmq_pub_socket.close()
class EventHandlerConnector(EventHandlerDownlinkConnector, EventHandlerUplinkConnector):
def __init__(self, downlink_bind_addr, uplink_bind_addr):
EventHandlerDownlinkConnector.__init__(self, downlink_bind_addr)
EventHandlerUplinkConnector.__init__(self, uplink_bind_addr)
def close(self):
EventHandlerDownlinkConnector.close(self)
EventHandlerUplinkConnector.close(self)

View File

@ -0,0 +1,7 @@
from enum import Enum
class Scope(Enum):
ZMQ = 'zmq'
WEBSOCKET = 'websocket'
BROADCAST = 'broadcast'

View File

@ -0,0 +1,104 @@
"""
TFW JSON message format
message:
{
"key": string, # addressing
"data": {...}, # payload
"trigger": string # FSM trigger
}
ZeroMQ's sub-pub sockets use enveloped messages
(http://zguide.zeromq.org/page:all#Pub-Sub-Message-Envelopes)
and TFW also uses them internally. This means that on ZMQ sockets
we always send the messages key separately and then the actual
message (which contains the key as well) like so:
socket.send_multipart([message['key'], message])
The purpose of this module is abstracting away this low level behaviour.
"""
import json
from functools import wraps
def serialize_tfw_msg(message):
"""
Create TFW multipart data from message dict
"""
return _serialize_all(message['key'], message)
def with_deserialize_tfw_msg(fun):
@wraps(fun)
def wrapper(message_parts):
message = deserialize_tfw_msg(*message_parts)
return fun(message)
return wrapper
def deserialize_tfw_msg(*args):
"""
Return message from TFW multipart data
"""
return _deserialize_all(*args)[1]
def _serialize_all(*args):
return tuple(
_serialize_single(arg)
for arg in args
)
def _deserialize_all(*args):
return tuple(
_deserialize_single(arg)
for arg in args
)
def _serialize_single(data):
"""
Return input as bytes
(serialize input if it is JSON)
"""
if not isinstance(data, str):
data = message_bytes(data)
return _encode_if_needed(data)
def message_bytes(message):
return json.dumps(message, sort_keys=True).encode()
def _deserialize_single(data):
"""
Try parsing input as JSON, return it as
string if parsing fails.
"""
try:
return json.loads(data)
except ValueError:
return _decode_if_needed(data)
def _encode_if_needed(value):
"""
Return input as bytes
(encode if input is string)
"""
if isinstance(value, str):
value = value.encode('utf-8')
return value
def _decode_if_needed(value):
"""
Return input as string
(decode if input is bytes)
"""
if isinstance(value, (bytes, bytearray)):
value = value.decode('utf-8')
return value

View File

@ -0,0 +1,65 @@
import logging
import zmq
from zmq.eventloop.zmqstream import ZMQStream
from .scope import Scope
from .serialization import serialize_tfw_msg, with_deserialize_tfw_msg
LOG = logging.getLogger(__name__)
class ServerDownlinkConnector:
def __init__(self, connect_addr):
self.keys = []
self._on_recv_callback = None
self._zmq_sub_socket = zmq.Context.instance().socket(zmq.SUB)
self._zmq_sub_socket.setsockopt(zmq.RCVHWM, 0)
self._zmq_sub_socket.connect(connect_addr)
self._zmq_sub_stream = ZMQStream(self._zmq_sub_socket)
def subscribe(self, *keys):
for key in keys:
self._zmq_sub_socket.setsockopt_string(zmq.SUBSCRIBE, key)
self.keys.append(key)
def unsubscribe(self, *keys):
for key in keys:
self._zmq_sub_socket.setsockopt_string(zmq.UNSUBSCRIBE, key)
self.keys.remove(key)
def register_callback(self, callback):
self._on_recv_callback = callback
self._zmq_sub_stream.on_recv(with_deserialize_tfw_msg(self._on_recv))
def _on_recv(self, message):
key = message['key']
if key in self.keys or '' in self.keys:
self._on_recv_callback(message)
def close(self):
self._zmq_sub_stream.close()
class ServerUplinkConnector:
def __init__(self, connect_addr):
self._zmq_push_socket = zmq.Context.instance().socket(zmq.PUSH)
self._zmq_push_socket.setsockopt(zmq.SNDHWM, 0)
self._zmq_push_socket.connect(connect_addr)
def send_message(self, message, scope=Scope.ZMQ):
message['scope'] = scope.value
self._zmq_push_socket.send_multipart(serialize_tfw_msg(message))
def close(self):
self._zmq_push_socket.close()
class ServerConnector(ServerDownlinkConnector, ServerUplinkConnector):
def __init__(self, downlink_connect_addr, uplink_connect_addr):
ServerDownlinkConnector.__init__(self, downlink_connect_addr)
ServerUplinkConnector.__init__(self, uplink_connect_addr)
def close(self):
ServerDownlinkConnector.close(self)
ServerUplinkConnector.close(self)

View File

@ -0,0 +1 @@
from .zmq_websocket_router import ZMQWebSocketRouter

View File

@ -0,0 +1,69 @@
import json
import logging
from tornado.websocket import WebSocketHandler
from tfw.internals.networking import Scope
LOG = logging.getLogger(__name__)
class ZMQWebSocketRouter(WebSocketHandler):
# pylint: disable=abstract-method,attribute-defined-outside-init
instances = set()
def initialize(self, **kwargs):
self.event_handler_connector = kwargs['event_handler_connector']
self.tfw_router = TFWRouter(self.send_to_zmq, self.send_to_websockets)
def send_to_zmq(self, message):
self.event_handler_connector.send_message(message)
@classmethod
def send_to_websockets(cls, message):
for instance in cls.instances:
instance.write_message(message)
def prepare(self):
type(self).instances.add(self)
def on_close(self):
type(self).instances.remove(self)
def open(self, *args, **kwargs):
LOG.debug('WebSocket connection initiated!')
self.event_handler_connector.register_callback(self.zmq_callback)
def zmq_callback(self, message):
LOG.debug('Received on ZMQ pull socket: %s', message)
self.tfw_router.route(message)
def on_message(self, message):
message = json.loads(message)
LOG.debug('Received on WebSocket: %s', message)
self.tfw_router.route(message)
# much secure, very cors, wow
def check_origin(self, origin):
return True
class TFWRouter:
def __init__(self, send_to_zmq, send_to_websockets):
self.send_to_zmq = send_to_zmq
self.send_to_websockets = send_to_websockets
def route(self, message):
scope = Scope(message.pop('scope', 'zmq'))
routing_table = {
Scope.ZMQ: self.send_to_zmq,
Scope.WEBSOCKET: self.send_to_websockets,
Scope.BROADCAST: self.broadcast
}
action = routing_table[scope]
action(message)
def broadcast(self, message):
self.send_to_zmq(message)
self.send_to_websockets(message)

132
tfw/logging.py Normal file
View File

@ -0,0 +1,132 @@
# pylint: disable=bad-whitespace
from datetime import datetime
from typing import TextIO, Union
from dataclasses import dataclass
from traceback import format_exception
from logging import DEBUG, getLogger, Handler, Formatter, Filter
class Color:
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
CYAN = '\033[36m'
WHITE = '\033[37m'
RESET = '\033[0m'
@dataclass
class Log:
stream: Union[str, TextIO]
formatter: Formatter
class Logger:
def __init__(self, logs, level=DEBUG):
self.root_logger = getLogger()
self.old_level = self.root_logger.level
self.new_level = level
self.handlers = []
for log in logs:
handler = LogHandler(log.stream)
handler.setFormatter(log.formatter)
self.handlers.append(handler)
def start(self):
self.root_logger.setLevel(self.new_level)
for handler in self.handlers:
self.root_logger.addHandler(handler)
def stop(self):
self.root_logger.setLevel(self.old_level)
for handler in self.handlers:
handler.close()
self.root_logger.removeHandler(handler)
class LogHandler(Handler):
def __init__(self, stream):
if isinstance(stream, str):
self.stream = open(stream, 'a+')
self.close_stream = True
else:
self.stream = stream
self.close_stream = False
super().__init__()
def emit(self, record):
entry = self.format(record)
self.stream.write(entry+'\n')
self.stream.flush()
def close(self):
if self.close_stream:
self.stream.close()
class LogFormatter(Formatter):
severity_to_color = {
'CRITICAL' : Color.RED,
'ERROR' : Color.RED,
'WARNING' : Color.YELLOW,
'INFO' : Color.GREEN,
'DEBUG' : Color.BLUE,
'NOTSET' : Color.CYAN
}
def __init__(self, limit):
self.limit = limit
super().__init__()
def format(self, record):
time = datetime.utcfromtimestamp(record.created).strftime('%H:%M:%S')
if record.args:
tuple_args = (record.args,) if isinstance(record.args, dict) else record.args
clean_args = tuple((self.trim(arg) for arg in tuple_args))
message = record.msg % clean_args
else:
message = record.msg
trace = '\n'+''.join(format_exception(*record.exc_info)) if record.exc_info else ''
return (f'[{Color.WHITE}{time}{Color.RESET}|>'
f'{self.severity_to_color[record.levelname]}{record.module}:'
f'{record.levelname.lower()}{Color.RESET}] {message}{trace}')
def trim(self, value):
if isinstance(value, dict):
return {k: self.trim(v) for k, v in value.items()}
if isinstance(value, str):
value_str = str(value)
return value_str if len(value_str) <= self.limit else f'{value_str[:self.limit]}...'
return value
class VerboseLogFormatter(Formatter):
def format(self, record): # pylint: disable=no-self-use
date = datetime.utcfromtimestamp(record.created).strftime('%H:%M:%S')
if record.args:
message = record.msg % record.args
else:
message = record.msg
trace = '\n'+''.join(format_exception(*record.exc_info)) if record.exc_info else ''
return (f'[{date}|>{record.module}:{record.levelname.lower()}] '
f'{message}{trace}')
class WhitelistFilter(Filter):
def __init__(self, names):
self.names = names
super().__init__()
def filter(self, record):
return record.module in self.names
class BlacklistFilter(Filter):
def __init__(self, names):
self.names = names
super().__init__()
def filter(self, record):
return record.module not in self.names

4
tfw/main/__init__.py Normal file
View File

@ -0,0 +1,4 @@
from .tfw_connector import TFWUplinkConnector, TFWConnector
from .event_handler_factory import EventHandlerFactory
from .signal_handling import setup_signal_handlers
from .tfw_server import TFWServer

View File

@ -0,0 +1,8 @@
from tfw.internals.event_handling import EventHandlerFactoryBase
from .tfw_connector import TFWConnector
class EventHandlerFactory(EventHandlerFactoryBase):
def _build_server_connector(self):
return TFWConnector()

View File

@ -0,0 +1,11 @@
from signal import signal, SIGTERM, SIGINT
from tfw.internals.event_handling import EventHandler
def setup_signal_handlers():
def stop(*_):
EventHandler.stop_all_instances()
exit(0)
signal(SIGTERM, stop)
signal(SIGINT, stop)

25
tfw/main/tfw_connector.py Normal file
View File

@ -0,0 +1,25 @@
from tfw.internals.networking import ServerConnector, ServerUplinkConnector
from tfw.config import TFWENV
class ConnAddrMixin:
@property
def uplink_conn_addr(self):
return f'tcp://localhost:{TFWENV.PULL_PORT}'
@property
def downlink_conn_addr(self):
return f'tcp://localhost:{TFWENV.PUB_PORT}'
class TFWUplinkConnector(ServerUplinkConnector, ConnAddrMixin):
def __init__(self):
super().__init__(self.uplink_conn_addr)
class TFWConnector(ServerConnector, ConnAddrMixin):
def __init__(self):
super().__init__(
self.downlink_conn_addr,
self.uplink_conn_addr
)

31
tfw/main/tfw_server.py Normal file
View File

@ -0,0 +1,31 @@
import logging
from tornado.web import Application
from tfw.internals.networking import EventHandlerConnector
from tfw.internals.server import ZMQWebSocketRouter
from tfw.config import TFWENV
LOG = logging.getLogger(__name__)
class TFWServer:
"""
This class handles the proxying of messages between the frontend and event handers.
It proxies messages from the "/ws" route to all event handlers subscribed to a ZMQ
SUB socket.
"""
def __init__(self):
self._event_handler_connector = EventHandlerConnector(
downlink_bind_addr=f'tcp://*:{TFWENV.PULL_PORT}',
uplink_bind_addr=f'tcp://*:{TFWENV.PUB_PORT}'
)
self.application = Application([(
r'/ws', ZMQWebSocketRouter, {
'event_handler_connector': self._event_handler_connector,
}
)])
def listen(self):
self.application.listen(TFWENV.WEB_PORT)