|
11 | 11 | import pytest |
12 | 12 |
|
13 | 13 | import aiohttp |
14 | | -from aiohttp import WSServerHandshakeError, web |
| 14 | +from aiohttp import WSServerHandshakeError, hdrs, web |
15 | 15 | from aiohttp.http import WSCloseCode, WSMsgType |
16 | 16 | from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer |
17 | 17 |
|
@@ -1659,3 +1659,56 @@ async def websocket_handler( |
1659 | 1659 | assert msg.type is aiohttp.WSMsgType.TEXT |
1660 | 1660 | assert msg.data == "success" |
1661 | 1661 | await ws.close() |
| 1662 | + |
| 1663 | + |
| 1664 | +async def test_prepare_after_client_disconnect(aiohttp_client: AiohttpClient) -> None: |
| 1665 | + """Test ConnectionResetError when client disconnects before ws.prepare(). |
| 1666 | +
|
| 1667 | + Reproduces the race condition where: |
| 1668 | + - Client connects and sends a WebSocket upgrade request |
| 1669 | + - Handler starts async work (e.g. authentication) before calling ws.prepare() |
| 1670 | + - Client disconnects while the handler is busy |
| 1671 | + - Handler then calls ws.prepare() → ConnectionResetError (not AssertionError) |
| 1672 | + """ |
| 1673 | + handler_started = asyncio.Event() |
| 1674 | + captured_protocol = None |
| 1675 | + |
| 1676 | + async def handler(request: web.Request) -> web.Response: |
| 1677 | + nonlocal captured_protocol |
| 1678 | + ws = web.WebSocketResponse() |
| 1679 | + captured_protocol = request._protocol |
| 1680 | + handler_started.set() |
| 1681 | + # Simulate async work (e.g., auth check) during which client disconnects. |
| 1682 | + await asyncio.sleep(0) |
| 1683 | + with pytest.raises(ConnectionResetError, match="Transport is not available"): |
| 1684 | + await ws.prepare(request) |
| 1685 | + return web.Response(status=503) |
| 1686 | + |
| 1687 | + app = web.Application() |
| 1688 | + app.router.add_route("GET", "/", handler) |
| 1689 | + client = await aiohttp_client(app) |
| 1690 | + |
| 1691 | + request_task = asyncio.create_task( |
| 1692 | + client.session.get( |
| 1693 | + client.make_url("/"), |
| 1694 | + headers={ |
| 1695 | + hdrs.UPGRADE: "websocket", |
| 1696 | + hdrs.CONNECTION: "Upgrade", |
| 1697 | + hdrs.SEC_WEBSOCKET_KEY: "dGhlIHNhbXBsZSBub25jZQ==", |
| 1698 | + hdrs.SEC_WEBSOCKET_VERSION: "13", |
| 1699 | + }, |
| 1700 | + ) |
| 1701 | + ) |
| 1702 | + |
| 1703 | + # Wait until the handler is running but has not yet called ws.prepare(). |
| 1704 | + await handler_started.wait() |
| 1705 | + assert captured_protocol is not None |
| 1706 | + |
| 1707 | + # Simulate the client disconnecting abruptly. |
| 1708 | + captured_protocol.force_close() |
| 1709 | + |
| 1710 | + # Yield so the handler can resume and hit the ConnectionResetError. |
| 1711 | + await asyncio.sleep(0) |
| 1712 | + |
| 1713 | + with contextlib.suppress(aiohttp.ServerDisconnectedError, aiohttp.ClientConnectionResetError): |
| 1714 | + await request_task |
0 commit comments