import logging
from os.path import join as joinpath
from os.path import basename
from os import makedirs
from datetime import datetime

from dateutil import parser as dateparser

from tfw.components.snapshot_provider import SnapshotProvider
from tfw.config import TFWENV
from tfw.networking import Scope
from .event_handler import EventHandler

LOG = logging.getLogger(__name__)


class DirectorySnapshottingEventHandler(EventHandler):
    """Take and restore snapshots of a set of directories on websocket commands.

    One ``SnapshotProvider`` is created per directory, each backed by its own
    git directory under ``TFWENV.SNAPSHOTS_DIR``. Incoming messages carry a
    ``data`` mapping whose ``command`` key selects one of ``take_snapshot``,
    ``restore_snapshot`` or ``exclude``. The handler's return value replaces
    ``message['data']`` and the message is sent back via ``send_message``.
    """

    def __init__(self, key, directories, exclude_unix_patterns=None):
        """
        :param key: handler key, passed through to ``EventHandler``
        :param directories: iterable of directory paths to snapshot
        :param exclude_unix_patterns: optional patterns handed to every
            ``SnapshotProvider`` (presumably Unix shell-style globs — TODO confirm)
        """
        super().__init__(key, scope=Scope.WEBSOCKET)
        self.snapshot_providers = {}
        self._exclude_unix_patterns = exclude_unix_patterns
        self.init_snapshot_providers(directories)

        self.command_handlers = {
            'take_snapshot': self.handle_take_snapshot,
            'restore_snapshot': self.handle_restore_snapshot,
            'exclude': self.handle_exclude
        }

    def init_snapshot_providers(self, directories):
        """Create a ``SnapshotProvider`` for each directory, keyed by its path."""
        for index, directory in enumerate(directories):
            git_dir = self.init_git_dir(index, directory)
            self.snapshot_providers[directory] = SnapshotProvider(
                directory,
                git_dir,
                self._exclude_unix_patterns
            )

    @staticmethod
    def init_git_dir(index, directory):
        """Create (if needed) and return the git directory for ``directory``.

        The ``index`` suffix keeps directories that share a basename apart.
        """
        git_dir = joinpath(
            TFWENV.SNAPSHOTS_DIR,
            f'{basename(directory)}-{index}'
        )
        makedirs(git_dir, exist_ok=True)
        return git_dir

    def handle_event(self, message):
        """Dispatch ``message`` to its command handler and send back the result.

        Malformed messages are logged and ignored. This covers a missing
        ``data`` or ``command``, an unknown command, a ``data`` that is not a
        mapping, or a handler rejecting its input by raising ``KeyError``.
        """
        try:
            data = message['data']
            command_handler = self.command_handlers[data['command']]
        except (KeyError, TypeError):
            # TypeError: ``data`` is not subscriptable by str or ``command`` is unhashable.
            self._log_invalid_message(message)
            return
        try:
            message['data'] = command_handler(data)
        except KeyError:
            # Command handlers signal invalid input by raising KeyError.
            self._log_invalid_message(message)
            return
        # Kept outside the try blocks: a KeyError raised while sending is not
        # an invalid incoming message and must not be reported as one.
        self.send_message(message)

    @staticmethod
    def _log_invalid_message(message):
        """Log that ``message`` is being ignored as invalid."""
        LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)

    def handle_take_snapshot(self, data):
        """Take a snapshot of every directory; return ``data`` unchanged."""
        LOG.debug('Taking snapshots of directories %s', list(self.snapshot_providers))
        for provider in self.snapshot_providers.values():
            provider.take_snapshot()
        return data

    def handle_restore_snapshot(self, data):
        """Restore every directory to its snapshot at the requested time.

        ``data['value']`` is parsed with ``dateutil``. When the key is absent,
        the current local time is used.

        :raises KeyError: if ``data['value']`` cannot be parsed as a date, so
            ``handle_event`` ignores the message like any other invalid one.
        """
        if 'value' in data:
            date = self._parse_date(data['value'])
        else:
            date = datetime.now()
        LOG.debug(
            'Restoring snapshots (@ %s) of directories %s',
            date,
            list(self.snapshot_providers)
        )
        for provider in self.snapshot_providers.values():
            provider.restore_snapshot(date)
        return data

    @staticmethod
    def _parse_date(value):
        """Parse ``value`` with dateutil, mapping parse failures to ``KeyError``."""
        try:
            return dateparser.parse(value)
        except (ValueError, OverflowError, TypeError) as error:
            # KeyError is this class's "invalid message" signal (see handle_event).
            raise KeyError(f'invalid snapshot date: {value!r}') from error

    def handle_exclude(self, data):
        """Replace the exclude patterns of every snapshot provider.

        :raises KeyError: if ``data['value']`` is missing, is not a list, or
            contains non-string entries.
        """
        exclude_unix_patterns = data['value']
        if not isinstance(exclude_unix_patterns, list) or not all(
                isinstance(pattern, str) for pattern in exclude_unix_patterns
        ):
            raise KeyError('exclude value must be a list of strings')
        # Keep the stored patterns in sync: init_snapshot_providers reads them.
        self._exclude_unix_patterns = exclude_unix_patterns
        for provider in self.snapshot_providers.values():
            provider.exclude = exclude_unix_patterns
        return data