From a1e3fe9813fb51189ef80a47c7e4b5e779f08bb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krist=C3=B3f=20T=C3=B3th?= Date: Sun, 23 Jun 2019 18:22:04 +0200 Subject: [PATCH] Allow reassigning handle_message() before start() --- pipe_io_server/pipe_io_server.py | 26 ++++++++++++++++---------- test_echo_server.py | 13 +++++++++++++ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/pipe_io_server/pipe_io_server.py b/pipe_io_server/pipe_io_server.py index 2715f40..92eef98 100644 --- a/pipe_io_server/pipe_io_server.py +++ b/pipe_io_server/pipe_io_server.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from threading import Thread, Event from typing import Callable @@ -8,15 +7,15 @@ from .pipe import Pipe from .terminate_process_on_failure import terminate_process_on_failure -class PipeIOServer(ABC, Thread): +class PipeIOServer(Thread): def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600): super().__init__(daemon=True) self._in_pipe, self._out_pipe = in_pipe, out_pipe self._create_pipes(permissions) self._stop_event = Event() - self._reader_thread, self._writer_thread = self._create_io_threads() - self._io_threads = (self._reader_thread, self._writer_thread) self._on_stop = lambda: None + self._reader_thread, self._writer_thread = None, None + self._io_threads = None def _create_pipes(self, permissions): Pipe(self.in_pipe).recreate(permissions) @@ -30,12 +29,6 @@ class PipeIOServer(ABC, Thread): def out_pipe(self): return self._out_pipe - def _create_io_threads(self): - reader_thread = PipeReaderThread(self.in_pipe, self._stop_event, self.handle_message) - writer_thread = PipeWriterThread(self.out_pipe, self._stop_event) - return reader_thread, writer_thread - - @abstractmethod def handle_message(self, message): raise NotImplementedError() @@ -44,11 +37,24 @@ class PipeIOServer(ABC, Thread): @terminate_process_on_failure def run(self): + self._init_io_threads() for thread in self._io_threads: thread.start() self._stop_event.wait() self._stop_threads() + def _init_io_threads(self): + self._reader_thread = PipeReaderThread( + self.in_pipe, + self._stop_event, + self.handle_message + ) + self._writer_thread = PipeWriterThread( + self.out_pipe, + self._stop_event + ) + self._io_threads = (self._reader_thread, self._writer_thread) + def stop(self): self._stop_event.set() if self.is_alive(): diff --git a/test_echo_server.py b/test_echo_server.py index d0437cb..56972db 100644 --- a/test_echo_server.py +++ b/test_echo_server.py @@ -11,6 +11,7 @@ from json import dumps, loads import pytest from echo_server import EchoPipeIOServer +from pipe_io_server import PipeIOServer @pytest.fixture @@ -156,3 +157,15 @@ def test_json_io(io_pipes): } io_pipes.send_message(dumps(test_data).encode()) assert loads(io_pipes.recv()) == test_data + + +def test_assign_message_handler(): + pipe_io = build_pipe_io_server(PipeIOServer) + pipe_io.handle_message = lambda msg: pipe_io.send_message(msg * 2) + pipe_io.start() + with IOPipes(pipe_io.in_pipe, pipe_io.out_pipe) as io_pipes: + for _ in range(100): + test_data = token_urlsafe(32).encode() + io_pipes.send_message(test_data) + assert io_pipes.recv() == test_data * 2 + pipe_io.stop()