"""Event handler that takes, restores and configures directory snapshots.

Every watched directory gets its own ``SnapshotProvider``, whose storage
directory is created under ``snapshots_dir``.
"""
import logging
from datetime import datetime
from os import makedirs
from os.path import basename
from os.path import join as joinpath

from dateutil import parser as dateparser

from tfw.internals.networking import Scope

from .snapshot_provider import SnapshotProvider

LOG = logging.getLogger(__name__)


class SnapshotHandler:
    """Dispatch ``'snapshot'`` messages to per-directory snapshot providers.

    A message's ``data['command']`` selects one of the entries in
    ``command_handlers``. The handler's return value replaces
    ``message['data']`` and the message is sent back with
    ``Scope.WEBSOCKET`` scope.

    Handlers signal an invalid payload by raising ``KeyError``;
    ``handle_event`` logs such messages and drops them without replying.
    """

    keys = ['snapshot']

    def __init__(
            self, *, directories, snapshots_dir, exclude_unix_patterns=None
    ):
        """
        :param directories: paths to snapshot; one provider is created per
            path.
        :param snapshots_dir: parent directory under which each provider's
            storage directory ``<basename>-<index>`` is created.
        :param exclude_unix_patterns: exclude patterns passed unchanged to
            every ``SnapshotProvider``.
        """
        self._snapshots_dir = snapshots_dir
        self.snapshot_providers = {}
        self._exclude_unix_patterns = exclude_unix_patterns
        self.init_snapshot_providers(directories)

        self.command_handlers = {
            'take_snapshot': self.handle_take_snapshot,
            'restore_snapshot': self.handle_restore_snapshot,
            'exclude': self.handle_exclude,
        }

    def init_snapshot_providers(self, directories):
        """Create one ``SnapshotProvider`` per directory, keyed by its path."""
        for index, directory in enumerate(directories):
            git_dir = self.init_git_dir(index, directory)
            self.snapshot_providers[directory] = SnapshotProvider(
                directory,
                git_dir,
                self._exclude_unix_patterns,
            )

    def init_git_dir(self, index, directory):
        """Create (if needed) and return the storage directory for *directory*.

        The *index* suffix keeps directories that share a basename from
        colliding inside ``snapshots_dir``.
        """
        git_dir = joinpath(
            self._snapshots_dir,
            f'{basename(directory)}-{index}',
        )
        makedirs(git_dir, exist_ok=True)
        return git_dir

    def handle_event(self, message, server_connector):
        """Run the command named in *message* and send the updated message back.

        Messages with no ``data``, no ``command``, an unknown command, or a
        payload that the handler rejects with ``KeyError`` are logged and
        ignored.
        """
        try:
            data = message['data']
            message['data'] = self.command_handlers[data['command']](data)
        except KeyError:
            LOG.error('IGNORING MESSAGE: Invalid message received: %s', message)
        else:
            # Sending sits outside the ``try`` so that a failure here is
            # not misreported as an invalid incoming message.
            server_connector.send_message(message, scope=Scope.WEBSOCKET)

    def handle_take_snapshot(self, data):
        """Take a snapshot of every watched directory and return *data*."""
        LOG.debug('Taking snapshots of directories %s', self.snapshot_providers.keys())
        for provider in self.snapshot_providers.values():
            provider.take_snapshot()
        return data

    def handle_restore_snapshot(self, data):
        """Restore every watched directory to a point in time and return *data*.

        ``data['value']`` is a timestamp string accepted by
        ``dateutil.parser.parse``. If it is absent or ``None``, the current
        local (naive) time is used.

        :raises KeyError: if ``data['value']`` cannot be parsed, so that
            ``handle_event`` drops the message instead of crashing.
        """
        raw_date = data.get('value')
        if raw_date is None:
            date = datetime.now()
        else:
            try:
                date = dateparser.parse(raw_date)
            except (ValueError, OverflowError, TypeError) as err:
                # Follow this class's convention: invalid payloads surface
                # as KeyError, which handle_event treats as "ignore message".
                raise KeyError(f'Invalid snapshot date: {raw_date!r}') from err

        LOG.debug(
            'Restoring snapshots (@ %s) of directories %s',
            date,
            self.snapshot_providers.keys(),
        )
        for provider in self.snapshot_providers.values():
            provider.restore_snapshot(date)
        return data

    def handle_exclude(self, data):
        """Set the exclude patterns of every provider and return *data*.

        :param data: must carry a list of exclude patterns under ``'value'``.
            The same list object is assigned to every provider.
        :raises KeyError: if ``'value'`` is missing or is not a list, so
            that ``handle_event`` drops the message.
        """
        exclude_unix_patterns = data['value']
        if not isinstance(exclude_unix_patterns, list):
            raise KeyError('exclude patterns must be a list')
        for provider in self.snapshot_providers.values():
            provider.exclude = exclude_unix_patterns
        return data