From ddc304ae8222b3bc594299550e6783e28e77a6a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krist=C3=B3f=20T=C3=B3th?= Date: Tue, 2 Apr 2019 16:55:19 +0200 Subject: [PATCH] Rework test cases to comply new API --- tests.py | 145 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 106 insertions(+), 39 deletions(-) diff --git a/tests.py b/tests.py index a0af73f..f4eea50 100644 --- a/tests.py +++ b/tests.py @@ -4,34 +4,112 @@ from os.path import exists, dirname, realpath, join from stat import S_ISFIFO from secrets import token_urlsafe from random import randint +from threading import Thread +from contextlib import contextmanager import pytest from echo_server import EchoPipeIOServer +BLOCK_TIMEOUT = 1 +SHUTDOWN_FAILURE_MSG = 'PipeIOServer failed to shut down!' + + @pytest.fixture -def pipe_io(): +def io_pipes(): + with pipe_io_server() as pipe_io: + with IOPipes(pipe_io.in_pipe, pipe_io.out_pipe) as io_pipes: + yield io_pipes + + +@contextmanager +def pipe_io_server(pipe_io_server_type=EchoPipeIOServer): + pipe_io_server = build_pipe_io_server(pipe_io_server_type) + thread = Thread(target=pipe_io_server.run) + thread.start() + yield pipe_io_server + pipe_io_server.stop() + thread.join(timeout=BLOCK_TIMEOUT) + if thread.is_alive(): + raise RuntimeError(SHUTDOWN_FAILURE_MSG) + + +def raise_if_thread_blocks(thread_target_function, unblock_function, timeout=BLOCK_TIMEOUT): + thread = Thread(target=thread_target_function) + thread.start() + unblock_function() + thread.join(timeout=timeout) + if thread.is_alive(): + raise RuntimeError(SHUTDOWN_FAILURE_MSG) + + +def build_pipe_io_server(pipe_io_server_type=EchoPipeIOServer): here = dirname(realpath(__file__)) - pipe_server = EchoPipeIOServer( + return pipe_io_server_type( join(here, 'in_pipe_tests'), join(here, 'out_pipe_tests') ) - pipe_server.run() - yield pipe_server - pipe_server.stop() -def test_pipes_exist(pipe_io): - for path in (pipe_io.in_pipe, pipe_io.out_pipe): +class IOPipes: + def __init__(self, in_pipe_path, out_pipe_path): + self.in_pipe_path = in_pipe_path + self.out_pipe_path = out_pipe_path + + def __enter__(self): + # pylint: disable=attribute-defined-outside-init + self.in_pipe = open(self.in_pipe_path, 'wb') + self.out_pipe = open(self.out_pipe_path, 'rb') + return self + + def __exit__(self, type_, value, traceback): + self.close() + + def close(self): + self.in_pipe.close() + self.out_pipe.close() + + def send(self, message): + self.in_pipe.write(message + b'\n') + self.in_pipe.flush() + + def recv(self): + return self.out_pipe.readline().rstrip(b'\n') + + +def test_run_creates_pipes(io_pipes): + for path in (io_pipes.in_pipe_path, io_pipes.out_pipe_path): assert exists(path) - - -def test_pipes_isfifo(pipe_io): - for path in (pipe_io.in_pipe, pipe_io.out_pipe): assert S_ISFIFO(stat(path).st_mode) +def test_stop(): + pipe_io = build_pipe_io_server() + def open_close_in_pipe(): + pipe_io.stop() + raise_if_thread_blocks(pipe_io.run, open_close_in_pipe) + for path in (pipe_io.in_pipe, pipe_io.out_pipe): + assert not exists(path) + + +def test_eof_stop(): + pipe_io = build_pipe_io_server() + def open_close_in_pipe(): + open(pipe_io.in_pipe, 'wb').close() + raise_if_thread_blocks(pipe_io.run, open_close_in_pipe) + + +def test_out_pipe_closed_stop(): + pipe_io = build_pipe_io_server() + def close_out_pipe_and_write(): + in_pipe = open(pipe_io.in_pipe, 'wb') + open(pipe_io.out_pipe, 'rb').close() + in_pipe.write(b'lel\n') + in_pipe.flush() + raise_if_thread_blocks(pipe_io.run, close_out_pipe_and_write) + + @pytest.mark.parametrize( 'test_data', [ 'Cats and cheese', @@ -44,22 +122,9 @@ def test_pipes_isfifo(pipe_io): token_urlsafe(32) ] ) -def test_io(pipe_io, test_data): - File(pipe_io.in_pipe).write(test_data.encode()) - assert File(pipe_io.out_pipe).read().decode() == test_data - - -class File: - def __init__(self, path): - self.path = path - - def write(self, what): - with open(self.path, 'wb') as ofile: - ofile.write(what) - - def read(self): - with open(self.path, 'rb') as ifile: - return ifile.read() +def test_io(io_pipes, test_data): + io_pipes.send(test_data.encode()) + assert io_pipes.recv().decode() == test_data @pytest.mark.parametrize( @@ -73,20 +138,22 @@ class File: 32*1024*1024 ] ) -def test_io_large_data(pipe_io, test_data_size): - test_data = urandom(test_data_size) - File(pipe_io.in_pipe).write(test_data) - assert File(pipe_io.out_pipe).read() == test_data +def test_io_large_data(io_pipes, test_data_size): + test_data = urandom(test_data_size).replace(b'\n', b'') + io_pipes.send(test_data) + received_data = io_pipes.recv() + assert received_data == test_data -def test_io_stress(pipe_io): +def test_io_stress(io_pipes): for _ in range(2222): - test_data = urandom(randint(1, 1024)) - File(pipe_io.in_pipe).write(test_data) - assert File(pipe_io.out_pipe).read() == test_data + test_data = urandom(randint(1, 1024)).replace(b'\n', b'') + io_pipes.send(test_data) + assert io_pipes.recv() == test_data -def test_stop_removes_pipes(pipe_io): - pipe_io.stop() - for path in (pipe_io.in_pipe, pipe_io.out_pipe): - assert not exists(path) +def test_io_newlines(io_pipes): + times = randint(1, 512) + io_pipes.send(b'\n' * times) + for _ in range(times + 1): # IOPipes.send appends +1 + assert io_pipes.recv() == b''