Allow reassigning handle_message() before start()

This commit is contained in:
Kristóf Tóth 2019-06-23 18:22:04 +02:00
parent 94b248fd89
commit a1e3fe9813
2 changed files with 29 additions and 10 deletions

View File

@ -1,4 +1,3 @@
from abc import ABC, abstractmethod
from threading import Thread, Event
from typing import Callable
@ -8,15 +7,15 @@ from .pipe import Pipe
from .terminate_process_on_failure import terminate_process_on_failure
class PipeIOServer(ABC, Thread):
class PipeIOServer(Thread):
def __init__(self, in_pipe=None, out_pipe=None, permissions=0o600):
super().__init__(daemon=True)
self._in_pipe, self._out_pipe = in_pipe, out_pipe
self._create_pipes(permissions)
self._stop_event = Event()
self._reader_thread, self._writer_thread = self._create_io_threads()
self._io_threads = (self._reader_thread, self._writer_thread)
self._on_stop = lambda: None
self._reader_thread, self._writer_thread = None, None
self._io_threads = None
def _create_pipes(self, permissions):
Pipe(self.in_pipe).recreate(permissions)
@ -30,12 +29,6 @@ class PipeIOServer(ABC, Thread):
def out_pipe(self):
return self._out_pipe
def _create_io_threads(self):
reader_thread = PipeReaderThread(self.in_pipe, self._stop_event, self.handle_message)
writer_thread = PipeWriterThread(self.out_pipe, self._stop_event)
return reader_thread, writer_thread
@abstractmethod
def handle_message(self, message):
raise NotImplementedError()
@ -44,11 +37,24 @@ class PipeIOServer(ABC, Thread):
@terminate_process_on_failure
def run(self):
self._init_io_threads()
for thread in self._io_threads:
thread.start()
self._stop_event.wait()
self._stop_threads()
def _init_io_threads(self):
self._reader_thread = PipeReaderThread(
self.in_pipe,
self._stop_event,
self.handle_message
)
self._writer_thread = PipeWriterThread(
self.out_pipe,
self._stop_event
)
self._io_threads = (self._reader_thread, self._writer_thread)
def stop(self):
self._stop_event.set()
if self.is_alive():

View File

@ -11,6 +11,7 @@ from json import dumps, loads
import pytest
from echo_server import EchoPipeIOServer
from pipe_io_server import PipeIOServer
@pytest.fixture
@ -156,3 +157,15 @@ def test_json_io(io_pipes):
}
io_pipes.send_message(dumps(test_data).encode())
assert loads(io_pipes.recv()) == test_data
def test_assign_message_handler():
pipe_io = build_pipe_io_server(PipeIOServer)
pipe_io.handle_message = lambda msg: pipe_io.send_message(msg * 2)
pipe_io.start()
with IOPipes(pipe_io.in_pipe, pipe_io.out_pipe) as io_pipes:
for _ in range(100):
test_data = token_urlsafe(32).encode()
io_pipes.send_message(test_data)
assert io_pipes.recv() == test_data * 2
pipe_io.stop()