diff --git a/tfw/internals/crypto/key_manager.py b/tfw/internals/crypto/key_manager.py index d177f11..4e56cad 100644 --- a/tfw/internals/crypto/key_manager.py +++ b/tfw/internals/crypto/key_manager.py @@ -7,7 +7,6 @@ from stat import S_IRUSR, S_IWUSR, S_IXUSR from tfw.internals.lazy import lazy_property from tfw.internals.ref_counter import RefCounter - KEYFILE = join(gettempdir(), 'tfw-auth.key') LOCKFILE = join(gettempdir(), 'tfw-auth.lock') @@ -45,4 +44,5 @@ class KeyManager: def _chmod_700_keyfile(self): chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR) + register(KeyManager.refcounter.teardown_instance) diff --git a/tfw/internals/ref_counter/ref_counter.py b/tfw/internals/ref_counter/ref_counter.py index 093fd53..4ff7889 100644 --- a/tfw/internals/ref_counter/ref_counter.py +++ b/tfw/internals/ref_counter/ref_counter.py @@ -1,4 +1,5 @@ from os import remove +from contextlib import contextmanager from fcntl import flock, LOCK_EX, LOCK_UN @@ -6,9 +7,14 @@ class RefCounter: def __init__(self, lockpath): self.lockpath = lockpath self._lockfile = open(self.lockpath, 'a+') + with self.locked(): + counter = self._read_counter() + self._write_counter(counter+1) + + @contextmanager + def locked(self): flock(self._lockfile, LOCK_EX) - counter = self._read_counter() - self._write_counter(counter+1) + yield flock(self._lockfile, LOCK_UN) def _read_counter(self): @@ -26,14 +32,13 @@ class RefCounter: self._lockfile.flush() def teardown_instance(self): - flock(self._lockfile, LOCK_EX) - counter = self._read_counter() - if counter <= 1: - remove(self.lockpath) - self.deallocate() - else: - self._write_counter(counter-1) - flock(self._lockfile, LOCK_UN) + with self.locked(): + counter = self._read_counter() + if counter <= 1: + remove(self.lockpath) + self.deallocate() + else: + self._write_counter(counter-1) self._lockfile.close() def deallocate(self): diff --git a/tfw/internals/ref_counter/test_ref_counter.py b/tfw/internals/ref_counter/test_ref_counter.py index 0666ad8..3b563b9 100644 --- a/tfw/internals/ref_counter/test_ref_counter.py +++ b/tfw/internals/ref_counter/test_ref_counter.py @@ -39,6 +39,7 @@ def context(): mkfifo(pipepath) yield CounterContext(join(workdir, 'test.lock'), pipepath) + def test_increment_decrement(context): counter, processes = 0, [] for _ in range(5): @@ -53,6 +54,7 @@ def test_increment_decrement(context): proc.send_signal(SIGINT) proc.wait() + def test_deallocate(context): state = False def trigger():