diff --git a/pipe_io_server/pipe_reader_thread.py b/pipe_io_server/pipe_reader_thread.py index c9488fd..8d2c32c 100644 --- a/pipe_io_server/pipe_reader_thread.py +++ b/pipe_io_server/pipe_reader_thread.py @@ -1,7 +1,6 @@ from threading import Thread -from os import kill, getpid -from signal import SIGTERM -from traceback import print_exc + +from .terminate_process_on_failure import terminate_process_on_failure class PipeReaderThread(Thread): @@ -12,17 +11,14 @@ class PipeReaderThread(Thread): self._message_handler = message_handler self._pipe_path = pipe_path + @terminate_process_on_failure def run(self): while True: with open(self._pipe_path, 'rb') as pipe: message = pipe.read() if message == self._stop_sequence: break - try: - self._message_handler(message) - except: # pylint: disable=bare-except - print_exc() - kill(getpid(), SIGTERM) + self._message_handler(message) def stop(self): with open(self._pipe_path, 'wb') as pipe: diff --git a/pipe_io_server/pipe_writer_thread.py b/pipe_io_server/pipe_writer_thread.py index 39a8161..7d9bbe8 100644 --- a/pipe_io_server/pipe_writer_thread.py +++ b/pipe_io_server/pipe_writer_thread.py @@ -1,6 +1,8 @@ from threading import Thread from queue import Queue +from .terminate_process_on_failure import terminate_process_on_failure + class PipeWriterThread(Thread): def __init__(self, pipe_path): @@ -11,6 +13,7 @@ class PipeWriterThread(Thread): def write(self, message): self._write_queue.put(message, block=True) + @terminate_process_on_failure def run(self): while True: message = self._write_queue.get(block=True) diff --git a/pipe_io_server/terminate_process_on_failure.py b/pipe_io_server/terminate_process_on_failure.py new file mode 100644 index 0000000..7a0804c --- /dev/null +++ b/pipe_io_server/terminate_process_on_failure.py @@ -0,0 +1,15 @@ +from functools import wraps +from os import kill, getpid +from signal import SIGTERM +from traceback import print_exc + + +def terminate_process_on_failure(fun): + @wraps(fun) + def wrapper(*args, **kwargs): + try: + return fun(*args, **kwargs) + except: # pylint: disable=bare-except + print_exc() + kill(getpid(), SIGTERM) + return wrapper