diff --git a/lib/tfw/yaml_fsm.py b/lib/tfw/yaml_fsm.py index b887a9c..82c955e 100644 --- a/lib/tfw/yaml_fsm.py +++ b/lib/tfw/yaml_fsm.py @@ -1,25 +1,21 @@ from subprocess import Popen, run -from functools import partial +from functools import partial, singledispatch from contextlib import suppress import yaml +import jinja2 from transitions import State from tfw import FSMBase class YamlFSM(FSMBase): - def __init__(self, config_file): - self.config = self.parse_config(config_file) + def __init__(self, config_file, jinja2_variables=None): + self.config = ConfigParser(config_file, jinja2_variables).config self.setup_states() super().__init__() # FSMBase.__init__() requires states self.setup_transitions() - @staticmethod - def parse_config(config_file): - with open(config_file, 'r') as ifile: - return yaml.safe_load(ifile) - def setup_states(self): self.for_config_states_and_transitions_do(self.wrap_callbacks_with_subprocess_call) self.states = [State(**state) for state in self.config['states']] @@ -62,3 +58,40 @@ def run_command_async(command, event): def command_statuscode_is_zero(command): return run(command, shell=True).returncode == 0 + + +class ConfigParser: + def __init__(self, config_file, jinja2_variables): + self.read_variables = singledispatch(self.read_variables) + self.read_variables.register(dict, self._read_variables_dict) + self.read_variables.register(str, self._read_variables_str) + + self.config = self.parse_config(config_file, jinja2_variables) + + def parse_config(self, config_file, jinja2_variables): + config_string = self.read_file(config_file) + if jinja2_variables is not None: + variables = self.read_variables(jinja2_variables) + template = jinja2.Environment(loader=jinja2.BaseLoader).from_string(config_string) + config_string = template.render(**variables) + return yaml.safe_load(config_string) + + @staticmethod + def read_file(filename): + with open(filename, 'r') as ifile: + return ifile.read() + + @staticmethod + def read_variables(variables): + raise TypeError(f'Invalid variables type {type(variables)}') + + @staticmethod + def _read_variables_str(variables): + if isinstance(variables, str): + with open(variables, 'r') as ifile: + return yaml.safe_load(ifile) + + @staticmethod + def _read_variables_dict(variables): + return variables + diff --git a/requirements.txt b/requirements.txt index a324ff0..4d3a900 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ transitions==0.6.4 terminado==0.8.1 watchdog==0.8.3 PyYAML==3.12 +Jinja2==2.10