From 31fea13a9aac10456bbdd7d2e7b3d389595709fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A1lint=20Bokros?= Date: Fri, 2 Feb 2018 17:30:26 +0100 Subject: [PATCH] Move serialization in one place --- lib/tfw/event_handler_base.py | 12 +++++------- lib/tfw/message_sender.py | 6 ++---- lib/tfw/networking/event_handler_connector.py | 4 ++-- lib/tfw/networking/serialization.py | 18 ++++++++++++++++++ lib/tfw/networking/server_connector.py | 3 ++- lib/tfw/networking/zmq_websocket_handler.py | 7 ++++--- 6 files changed, 33 insertions(+), 17 deletions(-) create mode 100644 lib/tfw/networking/serialization.py diff --git a/lib/tfw/event_handler_base.py b/lib/tfw/event_handler_base.py index c5c7bf2..dbff745 100644 --- a/lib/tfw/event_handler_base.py +++ b/lib/tfw/event_handler_base.py @@ -1,5 +1,6 @@ import json +from .networking.serialization import deserialize_all from .networking.server_connector import ServerConnector @@ -13,12 +14,11 @@ class EventHandlerBase: self.server_connector.register_callback(self.event_handler_callback) def event_handler_callback(self, msg_parts): - anchor, message = msg_parts + anchor, message = deserialize_all(*msg_parts) data_json = json.loads(message) - response = self.handle_event(anchor, data_json) if anchor != b'reset' else self.handle_reset(data_json) + response = self.handle_event(anchor, data_json) if anchor != 'reset' else self.handle_reset(data_json) if response is None: return - encoded_response = json.dumps(response).encode('utf-8') - self.server_connector.send(anchor, encoded_response) + self.server_connector.send(anchor, json.dumps(response)) def handle_event(self, anchor, data_json): raise NotImplementedError @@ -27,13 +27,11 @@ class EventHandlerBase: return None def message_other(self, anchor, data): - encoded_anchor = anchor.encode('utf-8') message = { 'anchor': anchor, 'data': data } - encoded_message = json.dumps(message).encode('utf-8') - self.server_connector.send(encoded_anchor, encoded_message) + self.server_connector.send(anchor, json.dumps(message)) def subscribe(self, anchor): if anchor not in self.subscriptions: diff --git a/lib/tfw/message_sender.py b/lib/tfw/message_sender.py index 138d25d..6b44f65 100644 --- a/lib/tfw/message_sender.py +++ b/lib/tfw/message_sender.py @@ -5,10 +5,8 @@ from .networking.server_connector import ServerUplinkConnector class MessageSender: - def __init__(self, custom_anchor=None): + def __init__(self, custom_anchor: str = None): self.server_connector = ServerUplinkConnector() - if isinstance(custom_anchor, bytes): - custom_anchor = custom_anchor.decode('utf-8') self.anchor = custom_anchor or 'message' def send(self, originator, message): @@ -21,4 +19,4 @@ class MessageSender: 'anchor': self.anchor, 'data': data } - self.server_connector.send(*[frame.encode('utf-8') for frame in (self.anchor, json.dumps(response))]) + self.server_connector.send(self.anchor, json.dumps(response)) diff --git a/lib/tfw/networking/event_handler_connector.py b/lib/tfw/networking/event_handler_connector.py index e5126a4..3099ff6 100644 --- a/lib/tfw/networking/event_handler_connector.py +++ b/lib/tfw/networking/event_handler_connector.py @@ -2,6 +2,7 @@ import zmq from zmq.eventloop import ioloop from zmq.eventloop.zmqstream import ZMQStream +from .serialization import serialize_all from ..config import PUBLISHER_PORT, RECEIVER_PORT from ..config.logs import logging log = logging.getLogger(__name__) @@ -40,5 +41,4 @@ class EventHandlerConnector(EventHandlerDownlinkConnector, EventHandlerUplinkCon def send_message(self, message: str, anchor: str = None): if not anchor: anchor = parse_anchor_from_message(message) - encoded_message = [part.encode('utf-8') for part in (anchor, message)] - self._zmq_pub_socket.send_multipart(encoded_message) + self._zmq_pub_socket.send_multipart(serialize_all(anchor, message)) diff --git a/lib/tfw/networking/serialization.py b/lib/tfw/networking/serialization.py new file mode 100644 index 0000000..33120e0 --- /dev/null +++ b/lib/tfw/networking/serialization.py @@ -0,0 +1,18 @@ +def encode_if_needed(value): + if isinstance(value, str): + value = value.encode('utf-8') + return value + + +def decode_if_needed(value): + if isinstance(value, (bytes, bytearray)): + value = value.decode('utf-8') + return value + + +def serialize_all(*args): + return [encode_if_needed(a) for a in args] + + +def deserialize_all(*args): + return [decode_if_needed(a) for a in args] diff --git a/lib/tfw/networking/server_connector.py b/lib/tfw/networking/server_connector.py index c1597a6..a1f03ee 100644 --- a/lib/tfw/networking/server_connector.py +++ b/lib/tfw/networking/server_connector.py @@ -3,6 +3,7 @@ from functools import partial from zmq.eventloop import ioloop from zmq.eventloop.zmqstream import ZMQStream +from .serialization import serialize_all from ..config import PUBLISHER_PORT, RECEIVER_PORT from ..util import ZMQConnectorBase @@ -29,7 +30,7 @@ class ServerUplinkConnector(ZMQConnectorBase): self._zmq_push_socket.connect('tcp://localhost:{}'.format(RECEIVER_PORT)) def send(self, anchor, response): - self._zmq_push_socket.send_multipart([anchor, response]) + self._zmq_push_socket.send_multipart(serialize_all(anchor, response)) class ServerConnector(ServerUplinkConnector, ServerDownlinkConnector): diff --git a/lib/tfw/networking/zmq_websocket_handler.py b/lib/tfw/networking/zmq_websocket_handler.py index 3a5e947..7f58ff8 100644 --- a/lib/tfw/networking/zmq_websocket_handler.py +++ b/lib/tfw/networking/zmq_websocket_handler.py @@ -1,6 +1,7 @@ import json from tornado.websocket import WebSocketHandler +from .serialization import deserialize_all from ..util import parse_anchor_from_message from .event_handler_connector import EventHandlerConnector from ..config.logs import logging @@ -20,9 +21,9 @@ class ZMQWebSocketHandler(WebSocketHandler): self._event_handler_connector.register_callback(self.zmq_callback) def zmq_callback(self, msg_parts): - anchor, data = msg_parts - log.debug('Received on pull socket: {}'.format(data.decode())) - self.write_message(data.decode()) + anchor, data = deserialize_all(*msg_parts) + log.debug('Received on pull socket: {}'.format(data)) + self.write_message(data) def on_message(self, message): log.debug('Received on WebSocket: {}'.format(message))