qemu

Форк
0
/
protocol.py 
596 строк · 18.2 Кб
1
import asyncio
2
from contextlib import contextmanager
3
import os
4
import socket
5
from tempfile import TemporaryDirectory
6

7
import avocado
8

9
from qemu.qmp import ConnectError, Runstate
10
from qemu.qmp.protocol import AsyncProtocol, StateError
11
from qemu.qmp.util import asyncio_run, create_task
12

13

14
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a ``fake_session`` instance variable. When it is set, the
    connect/accept paths below skip the real connection logic but still
    allow the reader/writer to start.

    Because the message type is None, the reader could otherwise "receive"
    None endlessly. To prevent that, an asyncio.Event named
    ``trigger_input`` gates the reader; setting the event simulates one
    incoming message.

    For symmetry with ``_do_recv``, a ``send_msg`` method is provided that
    "sends" a None message.

    For testing purposes, a ``simulate_disconnect`` method is also
    provided. It triggers a bottom-half disconnect without injecting any
    real errors into the reader/writer loops; in essence it performs
    exactly half of what ``disconnect()`` normally does.
    """
    def __init__(self, name=None):
        # When True, the connect/accept overrides below bypass real I/O.
        self.fake_session = False
        # Declared here; the Event itself is created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # Create the gating event before delegating to the base class,
        # because _do_recv() waits on it.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Pretend to listen: create the accept event and enter
            # CONNECTING without opening any socket.
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)  # yield once to the event loop
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            # No real client is awaited; just drop the fake accept event.
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Pretend to connect: enter CONNECTING without any real I/O.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)  # yield once to the event loop
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test pokes trigger_input, then re-arm the gate.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # Null messages go nowhere.
        pass

    async def send_msg(self) -> None:
        """Enqueue a None message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. It is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        is similar to what happens with an unhandled Exception in the
        reader/writer.

        Under normal circumstances, the library design requires you to
        await disconnect(), which awaits the disconnect task and returns
        bottom-half errors as a pre-condition to allowing the loop to
        return to IDLE.
        """
        self._schedule_disconnect()


class LineProtocol(AsyncProtocol[str]):
    """
    A minimal AsyncProtocol that exchanges newline-delimited text.

    Outgoing messages are UTF-8 encoded with a trailing newline appended.
    Every incoming message is decoded and recorded in ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Decoded messages, in the order they were received.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Enqueue *msg* on the outgoing queue."""
        await self._outgoing.put(msg)


def run_as_task(coro, allow_cancellation=False):
    """
    Schedule *coro* as a task and return that task.

    When *allow_cancellation* is true, cancelling the task makes it finish
    quietly (returning None) instead of propagating CancelledError.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_guarded())


@contextmanager
def jammed_socket():
    """
    Open a random unused TCP port on localhost, then jam it.

    A listening socket with a backlog of 1 is bound to an ephemeral port
    on 127.0.0.1 and never accepts. Two client connections are then made
    to fill its pending-connection queue, so further connection attempts
    to the yielded address are expected to hang rather than be refused.

    Every socket created here is closed on exit, including when setup
    fails partway through.

    :yield: The ``(host, port)`` address of the jammed listening socket.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket immediately so it is closed even if
        # setsockopt/bind/listen/getsockname below raise.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket; presumably because Linux admits up to backlog + 1
        # pending connections before further attempts stall.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            # Track before connect() so a failed connect cannot leak it.
            socks.append(client)
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()


class Smoke(avocado.Test):
    """Sanity checks on a freshly constructed, unconnected NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        self.assertEqual(repr(self.proto), "<NullProtocol runstate=IDLE>")

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        # A custom name shows up in the attribute, the logger and the repr.
        self.proto = NullProtocol('Steve')
        self.assertEqual(self.proto.name, 'Steve')
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(repr(self.proto),
                         "<NullProtocol name='Steve' runstate=IDLE>")


class TestBase(avocado.Test):
    """
    Common fixture for the asynchronous NullProtocol tests below.

    Each test starts with a fresh, IDLE NullProtocol named after the test
    class, and must leave it IDLE again by the time tearDown() runs.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(), if a test requests one.
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        """Async per-test setup hook; subclasses may extend it."""

    async def _asyncTearDown(self):
        """Async per-test teardown hook; waits for any runstate watcher."""
        if self.runstate_watcher:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function enables asyncio debug mode on the
        running loop, then awaits _asyncSetUp(), the test body and
        _asyncTearDown() in that order. Teardown is skipped if the test
        body raises.

        functools.wraps keeps the test's name and docstring on the wrapper,
        so reports and introspection see the real test instead of a
        generic "_wrapper".
        """
        import functools

        @functools.wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        Start a task asserting that the next runstate changes are *states*.

        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated. The task is stored in self.runstate_watcher and is
        awaited by _asyncTearDown(), which is where any assertion failure
        raised inside it surfaces.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)


class State(TestBase):
    """Runstate behavior while no session exists."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """Test calling disconnect() while already disconnected."""
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()


class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        """Make a connection attempt that is rejected immediately."""
        assert family in ('INET', 'UNIX')

        address = {
            'INET': ('127.0.0.1', 0),
            'UNIX': '/dev/null',
        }.get(family)
        if address is not None:
            await self.proto.connect(address)

    async def _hanging_connection(self):
        """Connect to a jammed port; the attempt should never finish."""
        with jammed_socket() as address:
            await self.proto.connect(address)

    async def _bad_connection_test(self, family: str):
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as ctx:
            await self._bad_connection(family)

        err = ctx.exception
        self.assertIsInstance(err.exc, OSError)
        self.assertEqual(err.error_message, "Failed to establish connection")

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # connect()/accept() are not tasks, so they cannot be cancelled
        # outright; wrap the attempt in a task and cancel that instead.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(self._hanging_connection(),
                              allow_cancellation=True)

        self.assertEqual(await self.proto.runstate_changed(),
                         Runstate.CONNECTING)

        # Insider baseball: the attempt has yielded *just* before the
        # real connection call, so kick the loop once more to make sure
        # it is truly wedged.
        await asyncio.sleep(0)

        attempt.cancel()
        await attempt

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(self._hanging_connection())

        # More insider baseball: to keep this test fast while still
        # guaranteeing the attempt gets a chance to start, first confirm
        # that it is hanging, then await it with a nearly-zero timeout.
        self.assertEqual(await self.proto.runstate_changed(),
                         Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(attempt, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        attempt = run_as_task(self._hanging_connection(),
                              allow_cancellation=True)

        self.assertEqual(await self.proto.runstate_changed(),
                         Runstate.CONNECTING)

        with self.assertRaises(StateError) as ctx:
            await self._bad_connection('UNIX')

        err = ctx.exception
        self.assertEqual(err.error_message,
                         "NullProtocol is currently connecting.")
        self.assertEqual(err.state, Runstate.CONNECTING)
        self.assertEqual(err.required, Runstate.IDLE)

        attempt.cancel()
        await attempt

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        attempt = run_as_task(self._hanging_connection(),
                              allow_cancellation=True)

        # Kick the loop to coerce the state change.
        await asyncio.sleep(0)
        assert self.proto.runstate == Runstate.CONNECTING

        # The transition to CONNECTING has already been missed.
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        attempt.cancel()
        await attempt


class Accept(Connect):
    """
    All of the same tests as Connect, but using the accept() interface.
    """
    async def _bad_connection(self, family: str):
        """
        Try to start a server in a way that fails immediately.

        INET binds to 192.0.2.1 (TEST-NET-1, RFC 5737). That documentation
        address is not expected to be configured on any local interface,
        so the bind fails with OSError. Unlike a public hostname, it needs
        no DNS lookup and has no network dependency.
        UNIX binds to '/dev/null', a path that already exists, so creating
        a socket there fails as well.
        """
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.start_server_and_accept(('192.0.2.1', 1))
        elif family == 'UNIX':
            await self.proto.start_server_and_accept('/dev/null')

    async def _hanging_connection(self):
        """Listen on a fresh socket path that no client ever connects to."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            await self.proto.start_server_and_accept(sock)


class FakeSession(TestBase):
    """
    Session-lifecycle tests for NullProtocol with fake_session enabled.

    With fake_session set, NullProtocol's connect/accept overrides skip
    the real socket work, so the '/not/a/real/path' address used below
    is never opened.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    async def _fake_accept(self):
        """Establish a fake session through start_server_and_accept()."""
        await self.proto.start_server_and_accept('/not/a/real/path')

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self._fake_accept()
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            trigger = self.proto.trigger_input
            # Pulse the event: set() wakes the waiting reader even though
            # the event is cleared again right away.
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Kick reader.

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self._fake_accept()

        logger_name = self.proto.logger.name
        with self.assertLogs(logger_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Kick writer; awaiting on a queue.put isn't sufficient to yield.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{logger_name}:--> None"])

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that starting another session is refused with StateError.

        :param current_state: The runstate the error is expected to report.
        :param error_message: The exact expected error message.
        :param accept: Use start_server_and_accept() if True, else connect().
        """
        if accept:
            entry_point = self.proto.start_server_and_accept
        else:
            entry_point = self.proto.connect

        with self.assertRaises(StateError) as ctx:
            await entry_point('/not/a/real/path')

        err = ctx.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    async def _require_running(self, accept: bool):
        """Reach RUNNING, then check that a new session is refused."""
        await self._fake_accept()

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=accept,
        )

    async def _require_disconnecting(self, accept: bool):
        """Reach DISCONNECTING, then check that a new session is refused."""
        await self._fake_accept()

        # Cheat: force a disconnect.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=accept,
        )

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self._require_running(accept=True)

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self._require_running(accept=False)

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self._require_disconnecting(accept=True)

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self._require_disconnecting(accept=False)


class SimpleSession(TestBase):
    """
    End-to-end test: a NullProtocol client connecting to a LineProtocol
    server over a socket path inside a temporary directory.
    """

    def setUp(self):
        super().setUp()
        self.server = LineProtocol(type(self).__name__ + '-server')

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        try:
            await self.server.disconnect()
        except EOFError:
            # The client has already hung up, so the server side may
            # report EOF while tearing down; that is expected here.
            pass
        await super()._asyncTearDown()

    @TestBase.async_test
    async def testSmoke(self):
        """Test a client connecting to a listening server."""
        with TemporaryDirectory(suffix='.qmp') as tmpdir:
            sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
            server_task = create_task(self.server.start_server_and_accept(sock))

            # Yield once so the server task can start listening before
            # the client tries to connect.
            await asyncio.sleep(0)
            await self.proto.connect(sock)

            # Wait for the server side to finish accepting. Any error it
            # hit is raised here instead of being dropped as "Task
            # exception was never retrieved", and teardown can no longer
            # race a still-pending accept.
            await server_task

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.