diff --git a/pipe_io_server.py b/pipe_io_server.py index 71241c6..9316c50 100644 --- a/pipe_io_server.py +++ b/pipe_io_server.py @@ -4,33 +4,37 @@ from os import mkfifo, remove from os.path import exists, join from signal import signal, SIGTERM, SIGINT from secrets import token_urlsafe +from collections import namedtuple class PipeWriterThread(Thread): - def __init__(self, to_write_queue, pipe_path): + def __init__(self, pipe_path): super().__init__() - self._to_write_queue = to_write_queue self._pipe_path = pipe_path + self._write_queue = Queue() + + def write(self, message): + self._write_queue.put(message, block=True) def run(self): while True: - message = self._to_write_queue.get(block=True) + message = self._write_queue.get(block=True) if message is None: break with open(self._pipe_path, 'wb') as pipe: pipe.write(message) def stop(self): - self._to_write_queue.put(None) + self._write_queue.put(None) self.join() class PipeReaderThread(Thread): _stop_sequence = b'stop_reading' - def __init__(self, results_queue, pipe_path): + def __init__(self, pipe_path, message_handler): super().__init__() - self._results_queue = results_queue + self._message_handler = message_handler self._pipe_path = pipe_path def run(self): @@ -39,7 +43,7 @@ class PipeReaderThread(Thread): message = pipe.read() if message == self._stop_sequence: break - self._results_queue.put(message, block=True) + self._message_handler(message) def stop(self): with open(self._pipe_path, 'wb') as pipe: @@ -67,11 +71,12 @@ class PipeIOServer: self.in_pipe, self.out_pipe = in_pipe, out_pipe self._create_pipes() - self._message_queue = Queue() - self._io_threads = { - 'reader': PipeReaderThread(self._message_queue, self.in_pipe), - 'writer': PipeWriterThread(self._message_queue, self.out_pipe) + io_threads_dict = { + 'reader': PipeReaderThread(self.in_pipe, self._handle_message), + 'writer': PipeWriterThread(self.out_pipe) } + IOThreadsTuple = namedtuple('IOThreadsTuple', sorted(io_threads_dict.keys())) + self._io_threads = IOThreadsTuple(**io_threads_dict) def _create_pipes(self): if not self.in_pipe or not self.out_pipe: @@ -80,12 +85,15 @@ class PipeIOServer: self.out_pipe = join('/tmp', f'out_pipe_{pipe_id}') PipeHandler(self.in_pipe, self.out_pipe).recreate() + def _handle_message(self, message): + self._io_threads.writer.write(message) + def run(self): - for thread in self._io_threads.values(): + for thread in self._io_threads: thread.start() def stop(self): - for thread in self._io_threads.values(): + for thread in self._io_threads: thread.stop() PipeHandler(self.in_pipe, self.out_pipe).remove()