 8193b9d148
			
		
	
	
		8193b9d148
		
	
	
	
	
		
			
			Tests a real connect, a real accept, and really sending and receiving a message over a UNIX socket. Brings coverage of protocol.py up to ~93%. Signed-off-by: John Snow <jsnow@redhat.com> Message-id: 20210915162955.333025-27-jsnow@redhat.com Signed-off-by: John Snow <jsnow@redhat.com>
		
			
				
	
	
		
			584 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			584 lines
		
	
	
		
			18 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import asyncio
 | |
| from contextlib import contextmanager
 | |
| import os
 | |
| import socket
 | |
| from tempfile import TemporaryDirectory
 | |
| 
 | |
| import avocado
 | |
| 
 | |
| from qemu.aqmp import ConnectError, Runstate
 | |
| from qemu.aqmp.protocol import AsyncProtocol, StateError
 | |
| from qemu.aqmp.util import asyncio_run, create_task
 | |
| 
 | |
| 
 | |
| class NullProtocol(AsyncProtocol[None]):
 | |
|     """
 | |
|     NullProtocol is a test mockup of an AsyncProtocol implementation.
 | |
| 
 | |
|     It adds a fake_session instance variable that enables a code path
 | |
|     that bypasses the actual connection logic, but still allows the
 | |
|     reader/writers to start.
 | |
| 
 | |
|     Because the message type is defined as None, an asyncio.Event named
 | |
|     'trigger_input' is created that prohibits the reader from
 | |
|     incessantly being able to yield None; this event can be poked to
 | |
|     simulate an incoming message.
 | |
| 
 | |
|     For testing symmetry with do_recv, an interface is added to "send" a
 | |
|     Null message.
 | |
| 
 | |
|     For testing purposes, a "simulate_disconnection" method is also
 | |
|     added which allows us to trigger a bottom half disconnect without
 | |
|     injecting any real errors into the reader/writer loops; in essence
 | |
|     it performs exactly half of what disconnect() normally does.
 | |
|     """
 | |
|     def __init__(self, name=None):
 | |
|         self.fake_session = False
 | |
|         self.trigger_input: asyncio.Event
 | |
|         super().__init__(name)
 | |
| 
 | |
|     async def _establish_session(self):
 | |
|         self.trigger_input = asyncio.Event()
 | |
|         await super()._establish_session()
 | |
| 
 | |
|     async def _do_accept(self, address, ssl=None):
 | |
|         if not self.fake_session:
 | |
|             await super()._do_accept(address, ssl)
 | |
| 
 | |
|     async def _do_connect(self, address, ssl=None):
 | |
|         if not self.fake_session:
 | |
|             await super()._do_connect(address, ssl)
 | |
| 
 | |
|     async def _do_recv(self) -> None:
 | |
|         await self.trigger_input.wait()
 | |
|         self.trigger_input.clear()
 | |
| 
 | |
|     def _do_send(self, msg: None) -> None:
 | |
|         pass
 | |
| 
 | |
|     async def send_msg(self) -> None:
 | |
|         await self._outgoing.put(None)
 | |
| 
 | |
|     async def simulate_disconnect(self) -> None:
 | |
|         """
 | |
|         Simulates a bottom-half disconnect.
 | |
| 
 | |
|         This method schedules a disconnection but does not wait for it
 | |
|         to complete. This is used to put the loop into the DISCONNECTING
 | |
|         state without fully quiescing it back to IDLE. This is normally
 | |
|         something you cannot coax AsyncProtocol to do on purpose, but it
 | |
|         will be similar to what happens with an unhandled Exception in
 | |
|         the reader/writer.
 | |
| 
 | |
|         Under normal circumstances, the library design requires you to
 | |
|         await on disconnect(), which awaits the disconnect task and
 | |
|         returns bottom half errors as a pre-condition to allowing the
 | |
|         loop to return back to IDLE.
 | |
|         """
 | |
|         self._schedule_disconnect()
 | |
| 
 | |
| 
 | |
| class LineProtocol(AsyncProtocol[str]):
 | |
|     def __init__(self, name=None):
 | |
|         super().__init__(name)
 | |
|         self.rx_history = []
 | |
| 
 | |
|     async def _do_recv(self) -> str:
 | |
|         raw = await self._readline()
 | |
|         msg = raw.decode()
 | |
|         self.rx_history.append(msg)
 | |
|         return msg
 | |
| 
 | |
|     def _do_send(self, msg: str) -> None:
 | |
|         assert self._writer is not None
 | |
|         self._writer.write(msg.encode() + b'\n')
 | |
| 
 | |
|     async def send_msg(self, msg: str) -> None:
 | |
|         await self._outgoing.put(msg)
 | |
| 
 | |
| 
 | |
| def run_as_task(coro, allow_cancellation=False):
 | |
|     """
 | |
|     Run a given coroutine as a task.
 | |
| 
 | |
|     Optionally, wrap it in a try..except block that allows this
 | |
|     coroutine to be canceled gracefully.
 | |
|     """
 | |
|     async def _runner():
 | |
|         try:
 | |
|             await coro
 | |
|         except asyncio.CancelledError:
 | |
|             if allow_cancellation:
 | |
|                 return
 | |
|             raise
 | |
|     return create_task(_runner())
 | |
| 
 | |
| 
 | |
| @contextmanager
 | |
| def jammed_socket():
 | |
|     """
 | |
|     Opens up a random unused TCP port on localhost, then jams it.
 | |
|     """
 | |
|     socks = []
 | |
| 
 | |
|     try:
 | |
|         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 | |
|         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 | |
|         sock.bind(('127.0.0.1', 0))
 | |
|         sock.listen(1)
 | |
|         address = sock.getsockname()
 | |
| 
 | |
|         socks.append(sock)
 | |
| 
 | |
|         # I don't *fully* understand why, but it takes *two* un-accepted
 | |
|         # connections to start jamming the socket.
 | |
|         for _ in range(2):
 | |
|             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 | |
|             sock.connect(address)
 | |
|             socks.append(sock)
 | |
| 
 | |
|         yield address
 | |
| 
 | |
|     finally:
 | |
|         for sock in socks:
 | |
|             sock.close()
 | |
| 
 | |
| 
 | |
| class Smoke(avocado.Test):
 | |
| 
 | |
|     def setUp(self):
 | |
|         self.proto = NullProtocol()
 | |
| 
 | |
|     def test__repr__(self):
 | |
|         self.assertEqual(
 | |
|             repr(self.proto),
 | |
|             "<NullProtocol runstate=IDLE>"
 | |
|         )
 | |
| 
 | |
|     def testRunstate(self):
 | |
|         self.assertEqual(
 | |
|             self.proto.runstate,
 | |
|             Runstate.IDLE
 | |
|         )
 | |
| 
 | |
|     def testDefaultName(self):
 | |
|         self.assertEqual(
 | |
|             self.proto.name,
 | |
|             None
 | |
|         )
 | |
| 
 | |
|     def testLogger(self):
 | |
|         self.assertEqual(
 | |
|             self.proto.logger.name,
 | |
|             'qemu.aqmp.protocol'
 | |
|         )
 | |
| 
 | |
|     def testName(self):
 | |
|         self.proto = NullProtocol('Steve')
 | |
| 
 | |
|         self.assertEqual(
 | |
|             self.proto.name,
 | |
|             'Steve'
 | |
|         )
 | |
| 
 | |
|         self.assertEqual(
 | |
|             self.proto.logger.name,
 | |
|             'qemu.aqmp.protocol.Steve'
 | |
|         )
 | |
| 
 | |
|         self.assertEqual(
 | |
|             repr(self.proto),
 | |
|             "<NullProtocol name='Steve' runstate=IDLE>"
 | |
|         )
 | |
| 
 | |
| 
 | |
| class TestBase(avocado.Test):
 | |
| 
 | |
|     def setUp(self):
 | |
|         self.proto = NullProtocol(type(self).__name__)
 | |
|         self.assertEqual(self.proto.runstate, Runstate.IDLE)
 | |
|         self.runstate_watcher = None
 | |
| 
 | |
|     def tearDown(self):
 | |
|         self.assertEqual(self.proto.runstate, Runstate.IDLE)
 | |
| 
 | |
|     async def _asyncSetUp(self):
 | |
|         pass
 | |
| 
 | |
|     async def _asyncTearDown(self):
 | |
|         if self.runstate_watcher:
 | |
|             await self.runstate_watcher
 | |
| 
 | |
|     @staticmethod
 | |
|     def async_test(async_test_method):
 | |
|         """
 | |
|         Decorator; adds SetUp and TearDown to async tests.
 | |
|         """
 | |
|         async def _wrapper(self, *args, **kwargs):
 | |
|             loop = asyncio.get_event_loop()
 | |
|             loop.set_debug(True)
 | |
| 
 | |
|             await self._asyncSetUp()
 | |
|             await async_test_method(self, *args, **kwargs)
 | |
|             await self._asyncTearDown()
 | |
| 
 | |
|         return _wrapper
 | |
| 
 | |
|     # Definitions
 | |
| 
 | |
|     # The states we expect a "bad" connect/accept attempt to transition through
 | |
|     BAD_CONNECTION_STATES = (
 | |
|         Runstate.CONNECTING,
 | |
|         Runstate.DISCONNECTING,
 | |
|         Runstate.IDLE,
 | |
|     )
 | |
| 
 | |
|     # The states we expect a "good" session to transition through
 | |
|     GOOD_CONNECTION_STATES = (
 | |
|         Runstate.CONNECTING,
 | |
|         Runstate.RUNNING,
 | |
|         Runstate.DISCONNECTING,
 | |
|         Runstate.IDLE,
 | |
|     )
 | |
| 
 | |
|     # Helpers
 | |
| 
 | |
|     async def _watch_runstates(self, *states):
 | |
|         """
 | |
|         This launches a task alongside (most) tests below to confirm that
 | |
|         the sequence of runstate changes that occur is exactly as
 | |
|         anticipated.
 | |
|         """
 | |
|         async def _watcher():
 | |
|             for state in states:
 | |
|                 new_state = await self.proto.runstate_changed()
 | |
|                 self.assertEqual(
 | |
|                     new_state,
 | |
|                     state,
 | |
|                     msg=f"Expected state '{state.name}'",
 | |
|                 )
 | |
| 
 | |
|         self.runstate_watcher = create_task(_watcher())
 | |
|         # Kick the loop and force the task to block on the event.
 | |
|         await asyncio.sleep(0)
 | |
| 
 | |
| 
 | |
| class State(TestBase):
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testSuperfluousDisconnect(self):
 | |
|         """
 | |
|         Test calling disconnect() while already disconnected.
 | |
|         """
 | |
|         await self._watch_runstates(
 | |
|             Runstate.DISCONNECTING,
 | |
|             Runstate.IDLE,
 | |
|         )
 | |
|         await self.proto.disconnect()
 | |
| 
 | |
| 
 | |
| class Connect(TestBase):
 | |
|     """
 | |
|     Tests primarily related to calling Connect().
 | |
|     """
 | |
|     async def _bad_connection(self, family: str):
 | |
|         assert family in ('INET', 'UNIX')
 | |
| 
 | |
|         if family == 'INET':
 | |
|             await self.proto.connect(('127.0.0.1', 0))
 | |
|         elif family == 'UNIX':
 | |
|             await self.proto.connect('/dev/null')
 | |
| 
 | |
|     async def _hanging_connection(self):
 | |
|         with jammed_socket() as addr:
 | |
|             await self.proto.connect(addr)
 | |
| 
 | |
|     async def _bad_connection_test(self, family: str):
 | |
|         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
 | |
| 
 | |
|         with self.assertRaises(ConnectError) as context:
 | |
|             await self._bad_connection(family)
 | |
| 
 | |
|         self.assertIsInstance(context.exception.exc, OSError)
 | |
|         self.assertEqual(
 | |
|             context.exception.error_message,
 | |
|             "Failed to establish connection"
 | |
|         )
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testBadINET(self):
 | |
|         """
 | |
|         Test an immediately rejected call to an IP target.
 | |
|         """
 | |
|         await self._bad_connection_test('INET')
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testBadUNIX(self):
 | |
|         """
 | |
|         Test an immediately rejected call to a UNIX socket target.
 | |
|         """
 | |
|         await self._bad_connection_test('UNIX')
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testCancellation(self):
 | |
|         """
 | |
|         Test what happens when a connection attempt is aborted.
 | |
|         """
 | |
|         # Note that accept() cannot be cancelled outright, as it isn't a task.
 | |
|         # However, we can wrap it in a task and cancel *that*.
 | |
|         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
 | |
|         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
 | |
| 
 | |
|         state = await self.proto.runstate_changed()
 | |
|         self.assertEqual(state, Runstate.CONNECTING)
 | |
| 
 | |
|         # This is insider baseball, but the connection attempt has
 | |
|         # yielded *just* before the actual connection attempt, so kick
 | |
|         # the loop to make sure it's truly wedged.
 | |
|         await asyncio.sleep(0)
 | |
| 
 | |
|         task.cancel()
 | |
|         await task
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testTimeout(self):
 | |
|         """
 | |
|         Test what happens when a connection attempt times out.
 | |
|         """
 | |
|         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
 | |
|         task = run_as_task(self._hanging_connection())
 | |
| 
 | |
|         # More insider baseball: to improve the speed of this test while
 | |
|         # guaranteeing that the connection even gets a chance to start,
 | |
|         # verify that the connection hangs *first*, then await the
 | |
|         # result of the task with a nearly-zero timeout.
 | |
| 
 | |
|         state = await self.proto.runstate_changed()
 | |
|         self.assertEqual(state, Runstate.CONNECTING)
 | |
|         await asyncio.sleep(0)
 | |
| 
 | |
|         with self.assertRaises(asyncio.TimeoutError):
 | |
|             await asyncio.wait_for(task, timeout=0)
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testRequire(self):
 | |
|         """
 | |
|         Test what happens when a connection attempt is made while CONNECTING.
 | |
|         """
 | |
|         await self._watch_runstates(*self.BAD_CONNECTION_STATES)
 | |
|         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
 | |
| 
 | |
|         state = await self.proto.runstate_changed()
 | |
|         self.assertEqual(state, Runstate.CONNECTING)
 | |
| 
 | |
|         with self.assertRaises(StateError) as context:
 | |
|             await self._bad_connection('UNIX')
 | |
| 
 | |
|         self.assertEqual(
 | |
|             context.exception.error_message,
 | |
|             "NullProtocol is currently connecting."
 | |
|         )
 | |
|         self.assertEqual(context.exception.state, Runstate.CONNECTING)
 | |
|         self.assertEqual(context.exception.required, Runstate.IDLE)
 | |
| 
 | |
|         task.cancel()
 | |
|         await task
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testImplicitRunstateInit(self):
 | |
|         """
 | |
|         Test what happens if we do not wait on the runstate event until
 | |
|         AFTER a connection is made, i.e., connect()/accept() themselves
 | |
|         initialize the runstate event. All of the above tests force the
 | |
|         initialization by waiting on the runstate *first*.
 | |
|         """
 | |
|         task = run_as_task(self._hanging_connection(), allow_cancellation=True)
 | |
| 
 | |
|         # Kick the loop to coerce the state change
 | |
|         await asyncio.sleep(0)
 | |
|         assert self.proto.runstate == Runstate.CONNECTING
 | |
| 
 | |
|         # We already missed the transition to CONNECTING
 | |
|         await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)
 | |
| 
 | |
|         task.cancel()
 | |
|         await task
 | |
| 
 | |
| 
 | |
| class Accept(Connect):
 | |
|     """
 | |
|     All of the same tests as Connect, but using the accept() interface.
 | |
|     """
 | |
|     async def _bad_connection(self, family: str):
 | |
|         assert family in ('INET', 'UNIX')
 | |
| 
 | |
|         if family == 'INET':
 | |
|             await self.proto.accept(('example.com', 1))
 | |
|         elif family == 'UNIX':
 | |
|             await self.proto.accept('/dev/null')
 | |
| 
 | |
|     async def _hanging_connection(self):
 | |
|         with TemporaryDirectory(suffix='.aqmp') as tmpdir:
 | |
|             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
 | |
|             await self.proto.accept(sock)
 | |
| 
 | |
| 
 | |
| class FakeSession(TestBase):
 | |
| 
 | |
|     def setUp(self):
 | |
|         super().setUp()
 | |
|         self.proto.fake_session = True
 | |
| 
 | |
|     async def _asyncSetUp(self):
 | |
|         await super()._asyncSetUp()
 | |
|         await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
 | |
| 
 | |
|     async def _asyncTearDown(self):
 | |
|         await self.proto.disconnect()
 | |
|         await super()._asyncTearDown()
 | |
| 
 | |
|     ####
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testFakeConnect(self):
 | |
| 
 | |
|         """Test the full state lifecycle (via connect) with a no-op session."""
 | |
|         await self.proto.connect('/not/a/real/path')
 | |
|         self.assertEqual(self.proto.runstate, Runstate.RUNNING)
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testFakeAccept(self):
 | |
|         """Test the full state lifecycle (via accept) with a no-op session."""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
|         self.assertEqual(self.proto.runstate, Runstate.RUNNING)
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testFakeRecv(self):
 | |
|         """Test receiving a fake/null message."""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         logname = self.proto.logger.name
 | |
|         with self.assertLogs(logname, level='DEBUG') as context:
 | |
|             self.proto.trigger_input.set()
 | |
|             self.proto.trigger_input.clear()
 | |
|             await asyncio.sleep(0)  # Kick reader.
 | |
| 
 | |
|         self.assertEqual(
 | |
|             context.output,
 | |
|             [f"DEBUG:{logname}:<-- None"],
 | |
|         )
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testFakeSend(self):
 | |
|         """Test sending a fake/null message."""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         logname = self.proto.logger.name
 | |
|         with self.assertLogs(logname, level='DEBUG') as context:
 | |
|             # Cheat: Send a Null message to nobody.
 | |
|             await self.proto.send_msg()
 | |
|             # Kick writer; awaiting on a queue.put isn't sufficient to yield.
 | |
|             await asyncio.sleep(0)
 | |
| 
 | |
|         self.assertEqual(
 | |
|             context.output,
 | |
|             [f"DEBUG:{logname}:--> None"],
 | |
|         )
 | |
| 
 | |
|     async def _prod_session_api(
 | |
|             self,
 | |
|             current_state: Runstate,
 | |
|             error_message: str,
 | |
|             accept: bool = True
 | |
|     ):
 | |
|         with self.assertRaises(StateError) as context:
 | |
|             if accept:
 | |
|                 await self.proto.accept('/not/a/real/path')
 | |
|             else:
 | |
|                 await self.proto.connect('/not/a/real/path')
 | |
| 
 | |
|         self.assertEqual(context.exception.error_message, error_message)
 | |
|         self.assertEqual(context.exception.state, current_state)
 | |
|         self.assertEqual(context.exception.required, Runstate.IDLE)
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testAcceptRequireRunning(self):
 | |
|         """Test that accept() cannot be called when Runstate=RUNNING"""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         await self._prod_session_api(
 | |
|             Runstate.RUNNING,
 | |
|             "NullProtocol is already connected and running.",
 | |
|             accept=True,
 | |
|         )
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testConnectRequireRunning(self):
 | |
|         """Test that connect() cannot be called when Runstate=RUNNING"""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         await self._prod_session_api(
 | |
|             Runstate.RUNNING,
 | |
|             "NullProtocol is already connected and running.",
 | |
|             accept=False,
 | |
|         )
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testAcceptRequireDisconnecting(self):
 | |
|         """Test that accept() cannot be called when Runstate=DISCONNECTING"""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         # Cheat: force a disconnect.
 | |
|         await self.proto.simulate_disconnect()
 | |
| 
 | |
|         await self._prod_session_api(
 | |
|             Runstate.DISCONNECTING,
 | |
|             ("NullProtocol is disconnecting."
 | |
|              " Call disconnect() to return to IDLE state."),
 | |
|             accept=True,
 | |
|         )
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testConnectRequireDisconnecting(self):
 | |
|         """Test that connect() cannot be called when Runstate=DISCONNECTING"""
 | |
|         await self.proto.accept('/not/a/real/path')
 | |
| 
 | |
|         # Cheat: force a disconnect.
 | |
|         await self.proto.simulate_disconnect()
 | |
| 
 | |
|         await self._prod_session_api(
 | |
|             Runstate.DISCONNECTING,
 | |
|             ("NullProtocol is disconnecting."
 | |
|              " Call disconnect() to return to IDLE state."),
 | |
|             accept=False,
 | |
|         )
 | |
| 
 | |
| 
 | |
| class SimpleSession(TestBase):
 | |
| 
 | |
|     def setUp(self):
 | |
|         super().setUp()
 | |
|         self.server = LineProtocol(type(self).__name__ + '-server')
 | |
| 
 | |
|     async def _asyncSetUp(self):
 | |
|         await super()._asyncSetUp()
 | |
|         await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
 | |
| 
 | |
|     async def _asyncTearDown(self):
 | |
|         await self.proto.disconnect()
 | |
|         try:
 | |
|             await self.server.disconnect()
 | |
|         except EOFError:
 | |
|             pass
 | |
|         await super()._asyncTearDown()
 | |
| 
 | |
|     @TestBase.async_test
 | |
|     async def testSmoke(self):
 | |
|         with TemporaryDirectory(suffix='.aqmp') as tmpdir:
 | |
|             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
 | |
|             server_task = create_task(self.server.accept(sock))
 | |
| 
 | |
|             # give the server a chance to start listening [...]
 | |
|             await asyncio.sleep(0)
 | |
|             await self.proto.connect(sock)
 |