Implement reference counting mechanism

This commit is contained in:
R. Richard 2019-07-31 16:30:06 +02:00 committed by therealkrispet
parent 911831fdb1
commit 25bd9aa0f3
6 changed files with 160 additions and 36 deletions

View File

@ -0,0 +1,2 @@
from .authentication import sign_message, verify_message
from .key_manager import KeyManager

View File

@ -1,10 +1,6 @@
from functools import wraps from functools import wraps
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from copy import deepcopy from copy import deepcopy
from hashlib import md5
from os import urandom, chmod
from os.path import exists
from stat import S_IRUSR, S_IWUSR, S_IXUSR
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.hashes import SHA256 from cryptography.hazmat.primitives.hashes import SHA256
@ -12,12 +8,6 @@ from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from tfw.internals.networking import message_bytes from tfw.internals.networking import message_bytes
from tfw.internals.lazy import lazy_property
from tfw.config import TFWENV
def message_checksum(message):
return md5(message_bytes(message)).hexdigest()
def sign_message(key, message): def sign_message(key, message):
@ -43,32 +33,6 @@ def verify_message(key, message):
return False return False
class KeyManager:
def __init__(self):
self.keyfile = TFWENV.AUTH_KEY
if not exists(self.keyfile):
self._init_auth_key()
@lazy_property
def auth_key(self):
with open(self.keyfile, 'rb') as ifile:
return ifile.read()
def _init_auth_key(self):
key = self.generate_key()
with open(self.keyfile, 'wb') as ofile:
ofile.write(key)
self._chmod_700_keyfile()
return key
@staticmethod
def generate_key():
return urandom(32)
def _chmod_700_keyfile(self):
chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
class HMAC: class HMAC:
def __init__(self, key, message): def __init__(self, key, message):
self.key = key self.key = key

View File

@ -0,0 +1,48 @@
from atexit import register
from tempfile import gettempdir
from os import urandom, chmod, remove
from os.path import exists, join
from stat import S_IRUSR, S_IWUSR, S_IXUSR
from tfw.internals.lazy import lazy_property
from tfw.internals.ref_counter import RefCounter
KEYFILE = join(gettempdir(), 'tfw-auth.key')
LOCKFILE = join(gettempdir(), 'tfw-auth.lock')
class KeyManagerRefCounter(RefCounter):
def deallocate(self):
if exists(KEYFILE):
remove(KEYFILE)
class KeyManager:
keyfile = KEYFILE
refcounter = KeyManagerRefCounter(LOCKFILE)
def __init__(self):
if not exists(self.keyfile):
self._init_auth_key()
@lazy_property
def auth_key(self):
with open(self.keyfile, 'rb') as ifile:
return ifile.read()
def _init_auth_key(self):
key = self.generate_key()
with open(self.keyfile, 'wb') as ofile:
ofile.write(key)
self._chmod_700_keyfile()
return key
@staticmethod
def generate_key():
return urandom(32)
def _chmod_700_keyfile(self):
chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
register(KeyManager.refcounter.teardown_instance)

View File

@ -0,0 +1 @@
from .ref_counter import RefCounter

View File

@ -0,0 +1,40 @@
from os import remove
from fcntl import flock, LOCK_EX, LOCK_UN
class RefCounter:
def __init__(self, lockpath):
self.lockpath = lockpath
self._lockfile = open(self.lockpath, 'a+')
flock(self._lockfile, LOCK_EX)
counter = self._read_counter()
self._write_counter(counter+1)
flock(self._lockfile, LOCK_UN)
def _read_counter(self):
self._lockfile.seek(0)
try:
counter = int(self._lockfile.read())
except ValueError:
counter = 0
return counter
def _write_counter(self, counter):
self._lockfile.seek(0)
self._lockfile.truncate()
self._lockfile.write(str(counter))
self._lockfile.flush()
def teardown_instance(self):
flock(self._lockfile, LOCK_EX)
counter = self._read_counter()
if counter <= 1:
remove(self.lockpath)
self.deallocate()
else:
self._write_counter(counter-1)
flock(self._lockfile, LOCK_UN)
self._lockfile.close()
def deallocate(self):
pass

View File

@ -0,0 +1,69 @@
# pylint: disable=redefined-outer-name
from dataclasses import dataclass
from textwrap import dedent
from os import mkfifo
from os.path import join
from signal import SIGINT
from subprocess import DEVNULL, Popen, PIPE
from tempfile import TemporaryDirectory
import pytest
from .ref_counter import RefCounter
@dataclass
class CounterContext:
lockpath: str
pipepath: str
@property
def source(self):
return dedent(f'''\
from time import sleep
from atexit import register
from ref_counter import RefCounter
counter = RefCounter('{self.lockpath}')
register(counter.teardown_instance)
print(flush=True)
while True:
sleep(1)
''')
@pytest.fixture
def context():
with TemporaryDirectory() as workdir:
pipepath = join(workdir, 'test.pipe')
mkfifo(pipepath)
yield CounterContext(join(workdir, 'test.lock'), pipepath)
def test_increment_decrement(context):
counter, processes = 0, []
for _ in range(5):
new_proc = Popen(['python3', '-c', context.source], stdout=PIPE, stderr=DEVNULL)
new_proc.stdout.readline()
processes.append(new_proc)
counter += 1
for proc in processes:
with open(context.lockpath, 'r') as lock:
assert lock.read() == str(counter)
counter -= 1
proc.send_signal(SIGINT)
proc.wait()
def test_deallocate(context):
state = False
def trigger():
nonlocal state
state = True
refcounters = []
for _ in range(32):
new_refc = RefCounter(context.lockpath)
new_refc.deallocate = trigger
refcounters.append(new_refc)
for refc in refcounters:
assert not state
refc.teardown_instance()
assert state