diff --git a/lib/tfw/event_handlers/event_handler_factory.py b/lib/tfw/event_handlers/event_handler_factory.py index c3a6b18..dc03c46 100644 --- a/lib/tfw/event_handlers/event_handler_factory.py +++ b/lib/tfw/event_handlers/event_handler_factory.py @@ -4,25 +4,34 @@ from .event_handler import EventHandler class EventHandlerFactoryBase: - def build(self, event_handler, *, keys=None): - analyzer = EventHandlerAnalyzer(event_handler, keys) - event_handler = self._build_from_callable(analyzer) - event_handler.start() - return event_handler - - def _build_from_callable(self, analyzer): + def build(self, event_handler, *, keys=None, event_handler_type=EventHandler): + builder = EventHandlerBuilder(event_handler, keys, event_handler_type) server_connector = self._build_server_connector() - server_connector.subscribe(analyzer.keys) - event_handler = EventHandler(server_connector) - event_handler.handle_event = analyzer.handle_event + real_event_handler = builder.build(server_connector) + event_handler.server_connector = server_connector with suppress(AttributeError): - event_handler.cleanup = analyzer.cleanup - return event_handler + event_handler.start() + real_event_handler.start() + return real_event_handler def _build_server_connector(self): raise NotImplementedError() +class EventHandlerBuilder: + def __init__(self, event_handler, supplied_keys, event_handler_type): + self._analyzer = EventHandlerAnalyzer(event_handler, supplied_keys) + self._event_handler_type = event_handler_type + + def build(self, server_connector): + server_connector.subscribe(*self._analyzer.keys) + event_handler = self._event_handler_type(server_connector) + event_handler.handle_event = self._analyzer.handle_event + with suppress(AttributeError): + event_handler.cleanup = self._analyzer.cleanup + return event_handler + + class EventHandlerAnalyzer: def __init__(self, event_handler, supplied_keys): self._event_handler = event_handler diff --git a/lib/tfw/event_handlers/test_event_handler.py b/lib/tfw/event_handlers/test_event_handler.py index 2005f51..edde537 100644 --- a/lib/tfw/event_handlers/test_event_handler.py +++ b/lib/tfw/event_handlers/test_event_handler.py @@ -1,4 +1,4 @@ -# pylint: disable=redefined-outer-name +# pylint: disable=redefined-outer-name,attribute-defined-outside-init from secrets import token_urlsafe from random import randint @@ -18,16 +18,16 @@ class MockServerConnector: self.keys = [] self._on_message = None - def register_callback(self, callback): - self._on_message = callback - def simulate_message(self, message): self._on_message(message) - def subscribe(self, keys): + def register_callback(self, callback): + self._on_message = callback + + def subscribe(self, *keys): self.keys.extend(keys) - def unsubscribe(self, keys): + def unsubscribe(self, *keys): for key in keys: self.keys.remove(key) @@ -38,17 +38,31 @@ class MockServerConnector: pass -class MockEventHandler: +class MockEventHandlerStub: def __init__(self): + self.server_connector = None + self.last_message = None self.cleaned_up = False + self.started = False - def handle_event(self, message, server_connector): - pass + def start(self): + self.started = True def cleanup(self): self.cleaned_up = True +class MockEventHandler(MockEventHandlerStub): + # pylint: disable=unused-argument + def handle_event(self, message, server_connector): + self.last_message = message + + +class MockCallable(MockEventHandlerStub): + def __call__(self, message, server_connector): + self.last_message = message + + @pytest.fixture def test_msg(): yield token_urlsafe(randint(16, 64)) @@ -63,12 +77,16 @@ def test_keys(): def test_build_from_object(test_keys, test_msg): - mock_eh = MockEventHandler() - def test_handle_event(message, server_connector): + mock_eh = MockEventHandlerStub() + def handle_event(message, server_connector): raise RuntimeError(message, server_connector.keys) - mock_eh.handle_event = test_handle_event + mock_eh.handle_event = handle_event + + assert not mock_eh.started eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys) + assert mock_eh.started + assert mock_eh.server_connector is eh.server_connector with pytest.raises(RuntimeError) as err: eh.server_connector.simulate_message(test_msg) msg, keys = err.args @@ -79,34 +97,52 @@ def test_build_from_object(test_keys, test_msg): assert mock_eh.cleaned_up -def test_build_from_object_with_keys(test_keys): +def test_build_from_object_with_keys(test_keys, test_msg): mock_eh = MockEventHandler() - mock_eh.keys = test_keys # pylint: disable=attribute-defined-outside-init + mock_eh.keys = test_keys + + assert not mock_eh.started eh = MockEventHandlerFactory().build(mock_eh) + assert mock_eh.server_connector.keys == test_keys + assert eh.server_connector is mock_eh.server_connector + assert mock_eh.started + assert not mock_eh.last_message + eh.server_connector.simulate_message(test_msg) + assert mock_eh.last_message == test_msg assert not mock_eh.cleaned_up EventHandler.stop_all_instances() assert mock_eh.cleaned_up - assert eh.server_connector.keys == test_keys + + +def test_build_from_simple_object(test_keys, test_msg): + class SimpleMockEventHandler: + # pylint: disable=no-self-use + def handle_event(self, message, server_connector): + raise RuntimeError(message, server_connector) + + mock_eh = SimpleMockEventHandler() + eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys) + + with pytest.raises(RuntimeError) as err: + eh.server_connector.simulate_message(test_msg) + msg, keys = err.args + assert msg == test_msg + assert keys == test_keys def test_build_from_callable(test_keys, test_msg): - class SomeCallable: - def __init__(self): - self.message = None - self.cleaned_up = False - def __call__(self, message, server_connector): - self.message = message - def cleanup(self): - self.cleaned_up = True + mock_eh = MockCallable() - mock_eh = SomeCallable() + assert not mock_eh.started eh = MockEventHandlerFactory().build(mock_eh, keys=test_keys) + assert mock_eh.started + assert mock_eh.server_connector is eh.server_connector assert eh.server_connector.keys == test_keys - assert not mock_eh.message + assert not mock_eh.last_message eh.server_connector.simulate_message(test_msg) - assert mock_eh.message == test_msg + assert mock_eh.last_message == test_msg assert not mock_eh.cleaned_up eh.stop() assert mock_eh.cleaned_up @@ -133,7 +169,7 @@ def test_build_from_lambda(test_keys, test_msg): eh.server_connector.simulate_message(test_msg) -def test_build_raises_if_no_key(): +def test_build_raises_if_no_key(test_keys): eh = MockEventHandler() with pytest.raises(ValueError): MockEventHandlerFactory().build(eh) @@ -145,3 +181,6 @@ def test_build_raises_if_no_key(): with pytest.raises(ValueError): MockEventHandlerFactory().build(lambda msg, sc: None) + + eh.keys = test_keys + MockEventHandlerFactory().build(eh)