# pylint: disable=redefined-outer-name
from os.path import join
from secrets import token_urlsafe
from random import randint
from tempfile import TemporaryDirectory
from contextlib import suppress

import pytest
from tornado.ioloop import IOLoop

from tfw.internals.networking import ZMQListener, ZMQConnector


@pytest.fixture
def _listener_and_connector():
    """Yield a ``(ZMQListener, ZMQConnector)`` pair wired over ipc sockets.

    Both socket paths live inside a temporary directory that is removed
    once the fixture is torn down.  The two endpoints receive the same
    pair of addresses in swapped order, so what one treats as its first
    address the other treats as its second.

    Teardown always attempts to close both endpoints: a failure while
    constructing the connector or while closing the listener no longer
    leaves the other socket open.
    """
    with TemporaryDirectory() as tmpdir:
        down_sock = join(tmpdir, 'down')
        up_sock = join(tmpdir, 'up')
        server_downlink = f'ipc://{down_sock}'
        server_uplink = f'ipc://{up_sock}'

        listener = ZMQListener(server_downlink, server_uplink)
        try:
            connector = ZMQConnector(server_uplink, server_downlink)
        except Exception:
            # Do not leak the already created listener if the connector
            # cannot be constructed.
            listener.close()
            raise

        try:
            yield listener, connector
        finally:
            # Same close order as before (listener first), but the
            # connector is closed even if closing the listener raises.
            try:
                listener.close()
            finally:
                connector.close()


@pytest.fixture
def zmq_listener(_listener_and_connector):
    """Expose the listener half of the shared listener/connector pair."""
    yield _listener_and_connector[0]


@pytest.fixture
def zmq_connector(_listener_and_connector):
    """Expose the connector half of the shared listener/connector pair."""
    yield _listener_and_connector[1]


def run_ioloop_once():
    """Run the current IOLoop briefly, then return.

    A stop is scheduled 0.1 after the call before the loop is started,
    so this function blocks until that stop callback fires.
    """
    # HACK: messages need time to travel through the OS network stack
    # while the IOLoop waits for them via select/epoll/kqueue, hence the
    # fixed delay instead of stopping immediately.
    loop = IOLoop.current()
    loop.call_later(0.1, loop.stop)
    loop.start()


@pytest.fixture
def test_messages():
    """Yield a list of 8 to 16 randomly generated message dicts.

    Every message has a fixed ``'key'`` entry plus randomly named entries
    holding an integer, a string, a nested dict and a list of strings.
    """
    def random_str():
        # URL-safe token generated from 4 to 8 random bytes.
        return token_urlsafe(randint(4, 8))

    def random_message():
        return {
            'key': random_str(),
            random_str(): randint(8192, 16384),
            random_str(): random_str(),
            random_str(): {
                random_str(): random_str(),
                random_str(): {random_str(): random_str()}
            },
            random_str(): [random_str(), random_str()]
        }

    message_count = randint(8, 16)
    yield [random_message() for _ in range(message_count)]


def wait_until_subscriber_connects(listener, connector):
    """Block until messages from ``listener`` actually reach ``connector``.

    Works around the ZMQ PUB-SUB "slow joiner" problem: a probe message is
    sent repeatedly until one copy is received, then a single sentinel is
    sent and everything received before it (leftover probes) is discarded.

    Both loops spin without sleeping and have no timeout, so this never
    returns if nothing gets through.
    """
    def arrived(expected):
        # A non-blocking receive raises IOError when nothing is available;
        # treat that the same as "not the message we are waiting for".
        with suppress(IOError):
            return connector.recv_message(block=False) == expected
        return False

    # Phase 1: keep sending until something makes it through the connection.
    probe = {'key': '-'}
    listener.send_message(probe)
    while not arrived(probe):
        listener.send_message(probe)

    # Phase 2: drain the probes still in flight, up to and including
    # the sentinel.
    marker = {'key': '_'}
    listener.send_message(marker)
    while not arrived(marker):
        pass


def test_server_downlink(zmq_listener, zmq_connector, test_messages):
    """Messages sent by the connector reach the listener callback in order."""
    received = []
    zmq_listener.register_callback(received.append)

    for outgoing in test_messages:
        zmq_connector.send_message(outgoing)
    run_ioloop_once()

    assert received == test_messages


def test_server_uplink(zmq_listener, zmq_connector, test_messages):
    """Messages sent by the listener reach the connector callback in order.

    The connector subscribes with an empty string before registering its
    callback.
    """
    received = []
    zmq_connector.subscribe('')
    zmq_connector.register_callback(received.append)

    for outgoing in test_messages:
        zmq_listener.send_message(outgoing)
    run_ioloop_once()

    assert received == test_messages


def test_connector_downlink_subscribe(zmq_listener, zmq_connector):
    """A connector subscribed to ``'1'`` only receives messages keyed ``'1'``.

    Two batches of 128 to 256 messages each are sent, one keyed ``'1'`` and
    one keyed ``'2'``; only the first batch may show up in the callback.
    """
    def keyed_batch(key):
        return [
            {'key': key, 'data': index}
            for index in range(randint(128, 256))
        ]

    wanted = keyed_batch('1')
    unwanted = keyed_batch('2')

    received = []
    zmq_connector.subscribe('1')
    zmq_connector.register_callback(received.append)

    for outgoing in wanted + unwanted:
        zmq_listener.send_message(outgoing)
    run_ioloop_once()

    assert received == wanted
    assert not any(msg in received for msg in unwanted)


def test_listener_sync_recv(zmq_listener, zmq_connector, test_messages):
    """Each message sent by the connector is returned by a synchronous
    ``recv_message`` call on the listener."""
    for expected in test_messages:
        zmq_connector.send_message(expected)
        received = zmq_listener.recv_message()
        assert received == expected


def test_connector_sync_recv(zmq_listener, zmq_connector, test_messages):
    """Each message sent by the listener is returned by a synchronous
    ``recv_message`` call on the connector.

    The connector subscribes with an empty string and the test waits for
    the subscription to become effective before sending anything.
    """
    zmq_connector.subscribe('')
    wait_until_subscriber_connects(zmq_listener, zmq_connector)

    for expected in test_messages:
        zmq_listener.send_message(expected)
        received = zmq_connector.recv_message()
        assert received == expected


def test_sync_recv_raises_if_callback_is_registered(zmq_listener, zmq_connector):
    """``recv_message`` raises ``RuntimeError`` on an endpoint that has a
    callback registered, for both the listener and the connector."""
    endpoints = (zmq_listener, zmq_connector)

    # Register a no-op callback on both endpoints before probing either.
    for endpoint in endpoints:
        endpoint.register_callback(lambda msg: None)

    for endpoint in endpoints:
        with pytest.raises(RuntimeError):
            endpoint.recv_message()