Allow reassigning handle_message() before start()
This commit is contained in:
		| @@ -1,4 +1,3 @@ | |||||||
| from abc import ABC, abstractmethod |  | ||||||
| from threading import Thread, Event | from threading import Thread, Event | ||||||
| from typing import Callable | from typing import Callable | ||||||
|  |  | ||||||
| @@ -8,15 +7,15 @@ from .pipe import Pipe | |||||||
| from .terminate_process_on_failure import terminate_process_on_failure | from .terminate_process_on_failure import terminate_process_on_failure | ||||||
|  |  | ||||||
|  |  | ||||||
| class PipeIOServer(ABC, Thread): | class PipeIOServer(Thread): | ||||||
|     def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600): |     def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600): | ||||||
|         super().__init__(daemon=True) |         super().__init__(daemon=True) | ||||||
|         self._in_pipe, self._out_pipe = in_pipe, out_pipe |         self._in_pipe, self._out_pipe = in_pipe, out_pipe | ||||||
|         self._create_pipes(permissions) |         self._create_pipes(permissions) | ||||||
|         self._stop_event = Event() |         self._stop_event = Event() | ||||||
|         self._reader_thread, self._writer_thread = self._create_io_threads() |  | ||||||
|         self._io_threads = (self._reader_thread, self._writer_thread) |  | ||||||
|         self._on_stop = lambda: None |         self._on_stop = lambda: None | ||||||
|  |         self._reader_thread, self._writer_thread = None, None | ||||||
|  |         self._io_threads = None | ||||||
|  |  | ||||||
|     def _create_pipes(self, permissions): |     def _create_pipes(self, permissions): | ||||||
|         Pipe(self.in_pipe).recreate(permissions) |         Pipe(self.in_pipe).recreate(permissions) | ||||||
| @@ -30,12 +29,6 @@ class PipeIOServer(ABC, Thread): | |||||||
|     def out_pipe(self): |     def out_pipe(self): | ||||||
|         return self._out_pipe |         return self._out_pipe | ||||||
|  |  | ||||||
|     def _create_io_threads(self): |  | ||||||
|         reader_thread = PipeReaderThread(self.in_pipe, self._stop_event, self.handle_message) |  | ||||||
|         writer_thread = PipeWriterThread(self.out_pipe, self._stop_event) |  | ||||||
|         return reader_thread, writer_thread |  | ||||||
|  |  | ||||||
|     @abstractmethod |  | ||||||
|     def handle_message(self, message): |     def handle_message(self, message): | ||||||
|         raise NotImplementedError() |         raise NotImplementedError() | ||||||
|  |  | ||||||
| @@ -44,11 +37,24 @@ class PipeIOServer(ABC, Thread): | |||||||
|  |  | ||||||
|     @terminate_process_on_failure |     @terminate_process_on_failure | ||||||
|     def run(self): |     def run(self): | ||||||
|  |         self._init_io_threads() | ||||||
|         for thread in self._io_threads: |         for thread in self._io_threads: | ||||||
|             thread.start() |             thread.start() | ||||||
|         self._stop_event.wait() |         self._stop_event.wait() | ||||||
|         self._stop_threads() |         self._stop_threads() | ||||||
|  |  | ||||||
|  |     def _init_io_threads(self): | ||||||
|  |         self._reader_thread = PipeReaderThread( | ||||||
|  |             self.in_pipe, | ||||||
|  |             self._stop_event, | ||||||
|  |             self.handle_message | ||||||
|  |         ) | ||||||
|  |         self._writer_thread = PipeWriterThread( | ||||||
|  |             self.out_pipe, | ||||||
|  |             self._stop_event | ||||||
|  |         ) | ||||||
|  |         self._io_threads = (self._reader_thread, self._writer_thread) | ||||||
|  |  | ||||||
|     def stop(self): |     def stop(self): | ||||||
|         self._stop_event.set() |         self._stop_event.set() | ||||||
|         if self.is_alive(): |         if self.is_alive(): | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ from json import dumps, loads | |||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
| from echo_server import EchoPipeIOServer | from echo_server import EchoPipeIOServer | ||||||
|  | from pipe_io_server import PipeIOServer | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| @@ -156,3 +157,15 @@ def test_json_io(io_pipes): | |||||||
|         } |         } | ||||||
|         io_pipes.send_message(dumps(test_data).encode()) |         io_pipes.send_message(dumps(test_data).encode()) | ||||||
|         assert loads(io_pipes.recv()) == test_data |         assert loads(io_pipes.recv()) == test_data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_assign_message_handler(): | ||||||
|  |     pipe_io = build_pipe_io_server(PipeIOServer) | ||||||
|  |     pipe_io.handle_message = lambda msg: pipe_io.send_message(msg * 2) | ||||||
|  |     pipe_io.start() | ||||||
|  |     with IOPipes(pipe_io.in_pipe, pipe_io.out_pipe) as io_pipes: | ||||||
|  |         for _ in range(100): | ||||||
|  |             test_data = token_urlsafe(32).encode() | ||||||
|  |             io_pipes.send_message(test_data) | ||||||
|  |             assert io_pipes.recv() == test_data * 2 | ||||||
|  |     pipe_io.stop() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user