diff --git a/tfw/internals/crypto/__init__.py b/tfw/internals/crypto/__init__.py new file mode 100644 index 0000000..0ff122a --- /dev/null +++ b/tfw/internals/crypto/__init__.py @@ -0,0 +1,2 @@ +from .authentication import sign_message, verify_message +from .key_manager import KeyManager diff --git a/tfw/internals/crypto.py b/tfw/internals/crypto/authentication.py similarity index 69% rename from tfw/internals/crypto.py rename to tfw/internals/crypto/authentication.py index 04aff16..26e9e11 100644 --- a/tfw/internals/crypto.py +++ b/tfw/internals/crypto/authentication.py @@ -1,10 +1,6 @@ from functools import wraps from base64 import b64encode, b64decode from copy import deepcopy -from hashlib import md5 -from os import urandom, chmod -from os.path import exists -from stat import S_IRUSR, S_IWUSR, S_IXUSR from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.hashes import SHA256 @@ -12,12 +8,6 @@ from cryptography.hazmat.primitives.hmac import HMAC as _HMAC from cryptography.exceptions import InvalidSignature from tfw.internals.networking import message_bytes -from tfw.internals.lazy import lazy_property -from tfw.config import TFWENV - - -def message_checksum(message): - return md5(message_bytes(message)).hexdigest() def sign_message(key, message): @@ -43,32 +33,6 @@ def verify_message(key, message): return False -class KeyManager: - def __init__(self): - self.keyfile = TFWENV.AUTH_KEY - if not exists(self.keyfile): - self._init_auth_key() - - @lazy_property - def auth_key(self): - with open(self.keyfile, 'rb') as ifile: - return ifile.read() - - def _init_auth_key(self): - key = self.generate_key() - with open(self.keyfile, 'wb') as ofile: - ofile.write(key) - self._chmod_700_keyfile() - return key - - @staticmethod - def generate_key(): - return urandom(32) - - def _chmod_700_keyfile(self): - chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR) - - class HMAC: def __init__(self, key, message): self.key = key diff --git a/tfw/internals/crypto/key_manager.py b/tfw/internals/crypto/key_manager.py new file mode 100644 index 0000000..d177f11 --- /dev/null +++ b/tfw/internals/crypto/key_manager.py @@ -0,0 +1,48 @@ +from atexit import register +from tempfile import gettempdir +from os import urandom, chmod, remove +from os.path import exists, join +from stat import S_IRUSR, S_IWUSR, S_IXUSR + +from tfw.internals.lazy import lazy_property +from tfw.internals.ref_counter import RefCounter + + +KEYFILE = join(gettempdir(), 'tfw-auth.key') +LOCKFILE = join(gettempdir(), 'tfw-auth.lock') + + +class KeyManagerRefCounter(RefCounter): + def deallocate(self): + if exists(KEYFILE): + remove(KEYFILE) + + +class KeyManager: + keyfile = KEYFILE + refcounter = KeyManagerRefCounter(LOCKFILE) + + def __init__(self): + if not exists(self.keyfile): + self._init_auth_key() + + @lazy_property + def auth_key(self): + with open(self.keyfile, 'rb') as ifile: + return ifile.read() + + def _init_auth_key(self): + key = self.generate_key() + with open(self.keyfile, 'wb') as ofile: + ofile.write(key) + self._chmod_700_keyfile() + return key + + @staticmethod + def generate_key(): + return urandom(32) + + def _chmod_700_keyfile(self): + chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR) + +register(KeyManager.refcounter.teardown_instance) diff --git a/tfw/internals/ref_counter/__init__.py b/tfw/internals/ref_counter/__init__.py new file mode 100644 index 0000000..25f2d50 --- /dev/null +++ b/tfw/internals/ref_counter/__init__.py @@ -0,0 +1 @@ +from .ref_counter import RefCounter diff --git a/tfw/internals/ref_counter/ref_counter.py b/tfw/internals/ref_counter/ref_counter.py new file mode 100644 index 0000000..093fd53 --- /dev/null +++ b/tfw/internals/ref_counter/ref_counter.py @@ -0,0 +1,40 @@ +from os import remove +from fcntl import flock, LOCK_EX, LOCK_UN + + +class RefCounter: + def __init__(self, lockpath): + self.lockpath = lockpath + self._lockfile = open(self.lockpath, 'a+') + flock(self._lockfile, LOCK_EX) + counter = self._read_counter() + self._write_counter(counter+1) + flock(self._lockfile, LOCK_UN) + + def _read_counter(self): + self._lockfile.seek(0) + try: + counter = int(self._lockfile.read()) + except ValueError: + counter = 0 + return counter + + def _write_counter(self, counter): + self._lockfile.seek(0) + self._lockfile.truncate() + self._lockfile.write(str(counter)) + self._lockfile.flush() + + def teardown_instance(self): + flock(self._lockfile, LOCK_EX) + counter = self._read_counter() + if counter <= 1: + remove(self.lockpath) + self.deallocate() + else: + self._write_counter(counter-1) + flock(self._lockfile, LOCK_UN) + self._lockfile.close() + + def deallocate(self): + pass diff --git a/tfw/internals/ref_counter/test_ref_counter.py b/tfw/internals/ref_counter/test_ref_counter.py new file mode 100644 index 0000000..0666ad8 --- /dev/null +++ b/tfw/internals/ref_counter/test_ref_counter.py @@ -0,0 +1,69 @@ +# pylint: disable=redefined-outer-name +from dataclasses import dataclass +from textwrap import dedent +from os import mkfifo +from os.path import join +from signal import SIGINT +from subprocess import DEVNULL, Popen, PIPE +from tempfile import TemporaryDirectory + +import pytest + +from .ref_counter import RefCounter + + +@dataclass +class CounterContext: + lockpath: str + pipepath: str + + @property + def source(self): + return dedent(f'''\ + from time import sleep + from atexit import register + from ref_counter import RefCounter + + counter = RefCounter('{self.lockpath}') + register(counter.teardown_instance) + print(flush=True) + while True: + sleep(1) + ''') + + +@pytest.fixture +def context(): + with TemporaryDirectory() as workdir: + pipepath = join(workdir, 'test.pipe') + mkfifo(pipepath) + yield CounterContext(join(workdir, 'test.lock'), pipepath) + +def test_increment_decrement(context): + counter, processes = 0, [] + for _ in range(5): + new_proc = Popen(['python3', '-c', context.source], stdout=PIPE, stderr=DEVNULL) + new_proc.stdout.readline() + processes.append(new_proc) + counter += 1 + for proc in processes: + with open(context.lockpath, 'r') as lock: + assert lock.read() == str(counter) + counter -= 1 + proc.send_signal(SIGINT) + proc.wait() + +def test_deallocate(context): + state = False + def trigger(): + nonlocal state + state = True + refcounters = [] + for _ in range(32): + new_refc = RefCounter(context.lockpath) + new_refc.deallocate = trigger + refcounters.append(new_refc) + for refc in refcounters: + assert not state + refc.teardown_instance() + assert state