mirror of
				https://github.com/avatao-content/baseimage-tutorial-framework
				synced 2025-11-04 07:32:55 +00:00 
			
		
		
		
	Implement reference counting mechanism
This commit is contained in:
		
				
					committed by
					
						
						therealkrispet
					
				
			
			
				
	
			
			
			
						parent
						
							911831fdb1
						
					
				
				
					commit
					25bd9aa0f3
				
			
							
								
								
									
										2
									
								
								tfw/internals/crypto/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								tfw/internals/crypto/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,2 @@
 | 
			
		||||
from .authentication import sign_message, verify_message
 | 
			
		||||
from .key_manager import KeyManager
 | 
			
		||||
@@ -1,10 +1,6 @@
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from base64 import b64encode, b64decode
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from hashlib import md5
 | 
			
		||||
from os import urandom, chmod
 | 
			
		||||
from os.path import exists
 | 
			
		||||
from stat import S_IRUSR, S_IWUSR, S_IXUSR
 | 
			
		||||
 | 
			
		||||
from cryptography.hazmat.backends import default_backend
 | 
			
		||||
from cryptography.hazmat.primitives.hashes import SHA256
 | 
			
		||||
@@ -12,12 +8,6 @@ from cryptography.hazmat.primitives.hmac import HMAC as _HMAC
 | 
			
		||||
from cryptography.exceptions import InvalidSignature
 | 
			
		||||
 | 
			
		||||
from tfw.internals.networking import message_bytes
 | 
			
		||||
from tfw.internals.lazy import lazy_property
 | 
			
		||||
from tfw.config import TFWENV
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def message_checksum(message):
 | 
			
		||||
    return md5(message_bytes(message)).hexdigest()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def sign_message(key, message):
 | 
			
		||||
@@ -43,32 +33,6 @@ def verify_message(key, message):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KeyManager:
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.keyfile = TFWENV.AUTH_KEY
 | 
			
		||||
        if not exists(self.keyfile):
 | 
			
		||||
            self._init_auth_key()
 | 
			
		||||
 | 
			
		||||
    @lazy_property
 | 
			
		||||
    def auth_key(self):
 | 
			
		||||
        with open(self.keyfile, 'rb') as ifile:
 | 
			
		||||
            return ifile.read()
 | 
			
		||||
 | 
			
		||||
    def _init_auth_key(self):
 | 
			
		||||
        key = self.generate_key()
 | 
			
		||||
        with open(self.keyfile, 'wb') as ofile:
 | 
			
		||||
            ofile.write(key)
 | 
			
		||||
        self._chmod_700_keyfile()
 | 
			
		||||
        return key
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def generate_key():
 | 
			
		||||
        return urandom(32)
 | 
			
		||||
 | 
			
		||||
    def _chmod_700_keyfile(self):
 | 
			
		||||
        chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class HMAC:
 | 
			
		||||
    def __init__(self, key, message):
 | 
			
		||||
        self.key = key
 | 
			
		||||
							
								
								
									
										48
									
								
								tfw/internals/crypto/key_manager.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								tfw/internals/crypto/key_manager.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,48 @@
 | 
			
		||||
from atexit import register
 | 
			
		||||
from tempfile import gettempdir
 | 
			
		||||
from os import urandom, chmod, remove
 | 
			
		||||
from os.path import exists, join
 | 
			
		||||
from stat import S_IRUSR, S_IWUSR, S_IXUSR
 | 
			
		||||
 | 
			
		||||
from tfw.internals.lazy import lazy_property
 | 
			
		||||
from tfw.internals.ref_counter import RefCounter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
KEYFILE = join(gettempdir(), 'tfw-auth.key')
 | 
			
		||||
LOCKFILE = join(gettempdir(), 'tfw-auth.lock')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KeyManagerRefCounter(RefCounter):
 | 
			
		||||
    def deallocate(self):
 | 
			
		||||
        if exists(KEYFILE):
 | 
			
		||||
            remove(KEYFILE)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class KeyManager:
 | 
			
		||||
    keyfile = KEYFILE
 | 
			
		||||
    refcounter = KeyManagerRefCounter(LOCKFILE)
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        if not exists(self.keyfile):
 | 
			
		||||
            self._init_auth_key()
 | 
			
		||||
 | 
			
		||||
    @lazy_property
 | 
			
		||||
    def auth_key(self):
 | 
			
		||||
        with open(self.keyfile, 'rb') as ifile:
 | 
			
		||||
            return ifile.read()
 | 
			
		||||
 | 
			
		||||
    def _init_auth_key(self):
 | 
			
		||||
        key = self.generate_key()
 | 
			
		||||
        with open(self.keyfile, 'wb') as ofile:
 | 
			
		||||
            ofile.write(key)
 | 
			
		||||
        self._chmod_700_keyfile()
 | 
			
		||||
        return key
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def generate_key():
 | 
			
		||||
        return urandom(32)
 | 
			
		||||
 | 
			
		||||
    def _chmod_700_keyfile(self):
 | 
			
		||||
        chmod(self.keyfile, S_IRUSR | S_IWUSR | S_IXUSR)
 | 
			
		||||
 | 
			
		||||
register(KeyManager.refcounter.teardown_instance)
 | 
			
		||||
							
								
								
									
										1
									
								
								tfw/internals/ref_counter/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								tfw/internals/ref_counter/__init__.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
from .ref_counter import RefCounter
 | 
			
		||||
							
								
								
									
										40
									
								
								tfw/internals/ref_counter/ref_counter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								tfw/internals/ref_counter/ref_counter.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,40 @@
 | 
			
		||||
from os import remove
 | 
			
		||||
from fcntl import flock, LOCK_EX, LOCK_UN
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RefCounter:
 | 
			
		||||
    def __init__(self, lockpath):
 | 
			
		||||
        self.lockpath = lockpath
 | 
			
		||||
        self._lockfile = open(self.lockpath, 'a+')
 | 
			
		||||
        flock(self._lockfile, LOCK_EX)
 | 
			
		||||
        counter = self._read_counter()
 | 
			
		||||
        self._write_counter(counter+1)
 | 
			
		||||
        flock(self._lockfile, LOCK_UN)
 | 
			
		||||
 | 
			
		||||
    def _read_counter(self):
 | 
			
		||||
        self._lockfile.seek(0)
 | 
			
		||||
        try:
 | 
			
		||||
            counter = int(self._lockfile.read())
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            counter = 0
 | 
			
		||||
        return counter
 | 
			
		||||
 | 
			
		||||
    def _write_counter(self, counter):
 | 
			
		||||
        self._lockfile.seek(0)
 | 
			
		||||
        self._lockfile.truncate()
 | 
			
		||||
        self._lockfile.write(str(counter))
 | 
			
		||||
        self._lockfile.flush()
 | 
			
		||||
 | 
			
		||||
    def teardown_instance(self):
 | 
			
		||||
        flock(self._lockfile, LOCK_EX)
 | 
			
		||||
        counter = self._read_counter()
 | 
			
		||||
        if counter <= 1:
 | 
			
		||||
            remove(self.lockpath)
 | 
			
		||||
            self.deallocate()
 | 
			
		||||
        else:
 | 
			
		||||
            self._write_counter(counter-1)
 | 
			
		||||
            flock(self._lockfile, LOCK_UN)
 | 
			
		||||
        self._lockfile.close()
 | 
			
		||||
 | 
			
		||||
    def deallocate(self):
 | 
			
		||||
        pass
 | 
			
		||||
							
								
								
									
										69
									
								
								tfw/internals/ref_counter/test_ref_counter.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								tfw/internals/ref_counter/test_ref_counter.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,69 @@
 | 
			
		||||
# pylint: disable=redefined-outer-name
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from textwrap import dedent
 | 
			
		||||
from os import mkfifo
 | 
			
		||||
from os.path import join
 | 
			
		||||
from signal import SIGINT
 | 
			
		||||
from subprocess import DEVNULL, Popen, PIPE
 | 
			
		||||
from tempfile import TemporaryDirectory
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from .ref_counter import RefCounter
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class CounterContext:
 | 
			
		||||
    lockpath: str
 | 
			
		||||
    pipepath: str
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def source(self):
 | 
			
		||||
        return dedent(f'''\
 | 
			
		||||
            from time import sleep
 | 
			
		||||
            from atexit import register
 | 
			
		||||
            from ref_counter import RefCounter
 | 
			
		||||
 | 
			
		||||
            counter = RefCounter('{self.lockpath}')
 | 
			
		||||
            register(counter.teardown_instance)
 | 
			
		||||
            print(flush=True)
 | 
			
		||||
            while True:
 | 
			
		||||
                sleep(1)
 | 
			
		||||
        ''')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def context():
 | 
			
		||||
    with TemporaryDirectory() as workdir:
 | 
			
		||||
        pipepath = join(workdir, 'test.pipe')
 | 
			
		||||
        mkfifo(pipepath)
 | 
			
		||||
        yield CounterContext(join(workdir, 'test.lock'), pipepath)
 | 
			
		||||
 | 
			
		||||
def test_increment_decrement(context):
 | 
			
		||||
    counter, processes = 0, []
 | 
			
		||||
    for _ in range(5):
 | 
			
		||||
        new_proc = Popen(['python3', '-c', context.source], stdout=PIPE, stderr=DEVNULL)
 | 
			
		||||
        new_proc.stdout.readline()
 | 
			
		||||
        processes.append(new_proc)
 | 
			
		||||
        counter += 1
 | 
			
		||||
    for proc in processes:
 | 
			
		||||
        with open(context.lockpath, 'r') as lock:
 | 
			
		||||
            assert lock.read() == str(counter)
 | 
			
		||||
            counter -= 1
 | 
			
		||||
        proc.send_signal(SIGINT)
 | 
			
		||||
        proc.wait()
 | 
			
		||||
 | 
			
		||||
def test_deallocate(context):
 | 
			
		||||
    state = False
 | 
			
		||||
    def trigger():
 | 
			
		||||
        nonlocal state
 | 
			
		||||
        state = True
 | 
			
		||||
    refcounters = []
 | 
			
		||||
    for _ in range(32):
 | 
			
		||||
        new_refc = RefCounter(context.lockpath)
 | 
			
		||||
        new_refc.deallocate = trigger
 | 
			
		||||
        refcounters.append(new_refc)
 | 
			
		||||
    for refc in refcounters:
 | 
			
		||||
        assert not state
 | 
			
		||||
        refc.teardown_instance()
 | 
			
		||||
    assert state
 | 
			
		||||
		Reference in New Issue
	
	Block a user