From b3e2e07a9b21923357367e08fed4074f93d967df Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 13 Mar 2025 17:40:47 +0100 Subject: [PATCH 01/10] Test wait_readable --- .github/workflows/test.yml | 7 +++++-- tests/test_socket.py | 24 +++++++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f6d8873..e47e860 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,6 +12,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: @@ -32,7 +33,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip install -e ".[test]" + run: | + pip install -e ".[test]" + pip install git+https://github.com/davidbrochart/anyio.git@show-error#egg=anyio --ignore-installed - name: Check with mypy and ruff if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | @@ -41,7 +44,7 @@ jobs: ruff check src - name: Run tests if: ${{ !((matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest')) }} - run: pytest --color=yes -v tests + run: pytest --color=yes -v tests -s - name: Run code coverage if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | diff --git a/tests/test_socket.py b/tests/test_socket.py index 56fc0cd..08397e6 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,8 +1,9 @@ import json +import socket import pytest import zmq -from anyio import create_task_group, fail_after, move_on_after, sleep, to_thread +from anyio import create_task_group, fail_after, move_on_after, sleep, to_thread, wait_all_tasks_blocked, wait_readable from anyioutils import CancelledError, Future, create_task from zmq_anyio import Poller, Socket @@ -342,3 +343,24 @@ async def test_close(create_bound_pair): await tg.start(b.start) a.close() b.close() + await sleep(0.1) + + +async def test_wait_readable(): + with fail_after(1): + s1, s2 = socket.socketpair() + with s1, s2: + s1.setblocking(False) + s2.setblocking(False) + async with create_task_group() as tg: + tg.start_soon(wait_readable, s2) + await wait_all_tasks_blocked() + await sleep(0.1) + tg.cancel_scope.cancel() + + s1, s2 = socket.socketpair() + with s1, s2: + s1.setblocking(False) + s2.setblocking(False) + s1.send(b"\x00") + await wait_readable(s2) From 6cf2d5192e04e06251c0001a80196b35b35ea015 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Mar 2025 10:49:59 +0100 Subject: [PATCH 02/10] Don't cancel wait_readable() --- src/zmq_anyio/_socket.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 0509db1..acc8210 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -895,19 +895,19 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._task_group, exception_handler=ignore_exceptions, ) + wait_readable_task = create_task( + wait_readable(self._shadow_sock), # type: ignore[arg-type] + self._task_group, + exception_handler=ignore_exceptions, + ) tasks = [ - create_task( - wait_readable(self._shadow_sock), # type: ignore[arg-type] - self._task_group, - exception_handler=ignore_exceptions, - ), + wait_readable_task, wait_stopped_task, ] done, pending = await wait( tasks, self._task_group, return_when=FIRST_COMPLETED ) - for task in pending: - task.cancel() + wait_stopped_task.cancel() if wait_stopped_task in done: break await self._handle_events() From 2260e0a28d69d42c3593235078ff511a813e1e29 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Mar 2025 14:24:21 +0100 Subject: [PATCH 03/10] Keep only relevant tests --- tests/test_socket.py | 646 +++++++++++++++++++++---------------------- 1 file changed, 323 insertions(+), 323 deletions(-) diff --git a/tests/test_socket.py b/tests/test_socket.py index 08397e6..67f7b0e 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -10,329 +10,329 @@ pytestmark = pytest.mark.anyio -async def test_context(context): - a, b = Socket(context, zmq.PAIR), Socket(context, zmq.PAIR) - port = a.bind_to_random_port("tcp://127.0.0.1") - b.connect(f'tcp://127.0.0.1:{port}') - a.send(b"Hello") - assert b.recv() == b"Hello" - async with a, b: - await a.asend(b"Hello").wait() - assert await b.arecv().wait() == b"Hello" - - -async def test_arecv_multipart(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) - async with b, a: - f = b.arecv_multipart() - assert not f.done() - await a.asend(b"hi").wait() - recvd = await f.wait() - assert recvd == [b"hi"] - - -async def test_arecv(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) - async with b, a: - f1 = b.arecv() - f2 = b.arecv() - assert not f1.done() - assert not f2.done() - await a.asend_multipart([b"hi", b"there"]).wait() - recvd = await f2.wait() - assert f1.done() - assert f1.result() == b"hi" - assert recvd == b"there" - - -async def test_arecv_json(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) - async with b, a: - f = b.arecv_json() - assert not f.done() - obj = dict(a=5) - await a.asend_json(obj).wait() - recvd = await f.wait() - assert f.done() - assert f.result() == obj - assert recvd == obj - - -async def test_arecv_send(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) - async with b, a: - f = b.arecv() - - def callback(future: Future) -> None: - b.send(b", World!") - - f.add_done_callback(callback) - a.send(b"Hello") - assert await a.arecv().wait() == b", World!" - - -async def test_inproc(sockets): - ctx = zmq.Context() - url = "inproc://test" - a = ctx.socket(zmq.PUSH) - b = ctx.socket(zmq.PULL) - a.linger = 0 - b.linger = 0 - sockets.extend([a, b]) - a.connect(url) - b.bind(url) - b = Socket(b) - async with b: - f = b.arecv() - await sleep(0.1) - a.send(b"hi") - assert await f.wait() == b"hi" - - -@pytest.mark.parametrize("total_threads", [1, 2]) -async def test_start_socket(total_threads, create_bound_pair): - to_thread.current_default_thread_limiter().total_tokens = total_threads - - a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) - a_started = False - b_started = False - - with pytest.raises(BaseException): - async with b: - b_started = True - with move_on_after(0.1): - async with a: - a_started = True - raise RuntimeError - - assert b_started - assert a_started - - to_thread.current_default_thread_limiter().total_tokens = 40 - -async def test_recv_multipart(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) - async with b, a: - f = b.arecv_multipart() - await a.asend(b"hi").wait() - assert await f.wait() == [b"hi"] - - -async def test_recv(create_bound_pair): - a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) - async with b, a: - f1 = b.arecv() - f2 = b.arecv() - await a.asend_multipart([b"hi", b"there"]).wait() - assert await f1.wait() == b"hi" - assert await f2.wait() == b"there" - - -@pytest.mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO") -async def test_recv_timeout(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - b.rcvtimeo = 100 - f1 = b.arecv() - b.rcvtimeo = 1000 - f2 = b.arecv_multipart() - with pytest.raises(zmq.Again): - await f1.wait() - await a.asend_multipart([b"hi", b"there"]).wait() - recvd = await f2.wait() - assert recvd == [b"hi", b"there"] - - -@pytest.mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO") -async def test_send_timeout(socket): - s = socket(zmq.PUSH) - s.sndtimeo = 100 - with pytest.raises(zmq.Again): - await s.send(b"not going anywhere") - - -async def test_recv_string(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - f = b.arecv_string() - msg = "πøøπ" - await a.asend_string(msg).wait() - recvd = await f.wait() - assert recvd == msg - - -async def test_recv_json(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - f = b.arecv_json() - obj = dict(a=5) - await a.asend_json(obj).wait() - recvd = await f.wait() - assert recvd == obj - - -async def test_recv_json_cancelled(push_pull): - async with create_task_group() as tg: - a, b = map(Socket, push_pull) - async with b, a: - f = b.arecv_json() - assert not f.done() - f.cancel(raise_exception=True) - # cycle eventloop to allow cancel events to fire - await sleep(0) - obj = dict(a=5) - await a.asend_json(obj).wait() - with pytest.raises(CancelledError): - recvd = await f.wait() - assert f.cancelled() - assert f.done() - # give it a chance to incorrectly consume the event - events = await b.apoll(timeout=5).wait() - assert events - await sleep(0) - # make sure cancelled recv didn't eat up event - f = b.arecv_json() - with move_on_after(5): - recvd = await f.wait() - assert recvd == obj - - -async def test_recv_pyobj(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - f = b.arecv_pyobj() - obj = dict(a=5) - await a.asend_pyobj(obj).wait() - recvd = await f.wait() - assert recvd == obj - - -async def test_custom_serialize(create_bound_pair): - def serialize(msg): - frames = [] - frames.extend(msg.get("identities", [])) - content = json.dumps(msg["content"]).encode("utf8") - frames.append(content) - return frames - - def deserialize(frames): - identities = frames[:-1] - content = json.loads(frames[-1].decode("utf8")) - return { - "identities": identities, - "content": content, - } - - a, b = map(Socket, create_bound_pair(zmq.DEALER, zmq.ROUTER)) - async with b, a: - - msg = { - "content": { - "a": 5, - "b": "bee", - } - } - await a.asend_serialized(msg, serialize).wait() - recvd = await b.arecv_serialized(deserialize).wait() - assert recvd["content"] == msg["content"] - assert recvd["identities"] - # bounce back, tests identities - await b.asend_serialized(recvd, serialize).wait() - r2 = await a.arecv_serialized(deserialize).wait() - assert r2["content"] == msg["content"] - assert not r2["identities"] - - -@pytest.mark.skip(reason="FIXME: sometimes raises CancelledError") -async def test_custom_serialize_error(dealer_router): - a, b = map(Socket, dealer_router) - async with b, a: - await a.asend(b"not json").wait() - with pytest.raises(TypeError): - await b.arecv_serialized(json.loads).wait() - - -async def test_recv_dontwait(push_pull): - push, pull = map(Socket, push_pull) - async with pull, push: - f = pull.arecv(zmq.DONTWAIT) - with pytest.raises(zmq.Again): - await f.wait() - await push.asend(b"ping").wait() - await pull.apoll().wait() # ensure message will be waiting - msg = await pull.arecv(zmq.DONTWAIT).wait() - assert msg == b"ping" - - -async def test_recv_cancel(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - f1 = b.arecv() - f2 = b.arecv_multipart() - f1.cancel() - assert f1.done() - assert not f2.done() - await a.asend_multipart([b"hi", b"there"]).wait() - recvd = await f2.wait() - assert f1.cancelled() - assert f2.done() - assert recvd == [b"hi", b"there"] - - -async def test_poll(push_pull): - a, b = map(Socket, push_pull) - async with b, a: - f = b.apoll(timeout=0) - await sleep(0.1) - assert f.result() == 0 - - f = b.apoll(timeout=1) - assert not f.done() - evt = await f.wait() - - assert evt == 0 - - f = b.apoll(timeout=1000) - assert not f.done() - await a.asend_multipart([b"hi", b"there"]).wait() - evt = await f.wait() - assert evt == zmq.POLLIN - recvd = await b.arecv_multipart().wait() - assert recvd == [b"hi", b"there"] - - -async def test_poll_base_socket(sockets): - ctx = zmq.Context() - url = "inproc://test" - a = Socket(ctx.socket(zmq.PUSH)) - b = Socket(ctx.socket(zmq.PULL)) - sockets.extend([a, b]) - a.bind(url) - b.connect(url) - - poller = Poller() - poller.register(b, zmq.POLLIN) - - async with create_task_group() as tg: - f = poller.apoll(tg, timeout=1000) - assert not f.done() - a.send_multipart([b"hi", b"there"]) - evt = await f.wait() - assert evt == [(b, zmq.POLLIN)] - recvd = b.recv_multipart() - assert recvd == [b"hi", b"there"] - - -@pytest.mark.skip(reason="FIXME: sometimes raises ZMQError") -async def test_poll_on_closed_socket(push_pull): - a, b = push_pull - b = Socket(b) - async with create_task_group() as tg: - async with b: - f = b.apoll(timeout=1) - await sleep(0.1) - - assert f.done() +#async def test_context(context): +# a, b = Socket(context, zmq.PAIR), Socket(context, zmq.PAIR) +# port = a.bind_to_random_port("tcp://127.0.0.1") +# b.connect(f'tcp://127.0.0.1:{port}') +# a.send(b"Hello") +# assert b.recv() == b"Hello" +# async with a, b: +# await a.asend(b"Hello").wait() +# assert await b.arecv().wait() == b"Hello" +# +# +#async def test_arecv_multipart(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) +# async with b, a: +# f = b.arecv_multipart() +# assert not f.done() +# await a.asend(b"hi").wait() +# recvd = await f.wait() +# assert recvd == [b"hi"] +# +# +#async def test_arecv(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) +# async with b, a: +# f1 = b.arecv() +# f2 = b.arecv() +# assert not f1.done() +# assert not f2.done() +# await a.asend_multipart([b"hi", b"there"]).wait() +# recvd = await f2.wait() +# assert f1.done() +# assert f1.result() == b"hi" +# assert recvd == b"there" +# +# +#async def test_arecv_json(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) +# async with b, a: +# f = b.arecv_json() +# assert not f.done() +# obj = dict(a=5) +# await a.asend_json(obj).wait() +# recvd = await f.wait() +# assert f.done() +# assert f.result() == obj +# assert recvd == obj +# +# +#async def test_arecv_send(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) +# async with b, a: +# f = b.arecv() +# +# def callback(future: Future) -> None: +# b.send(b", World!") +# +# f.add_done_callback(callback) +# a.send(b"Hello") +# assert await a.arecv().wait() == b", World!" +# +# +#async def test_inproc(sockets): +# ctx = zmq.Context() +# url = "inproc://test" +# a = ctx.socket(zmq.PUSH) +# b = ctx.socket(zmq.PULL) +# a.linger = 0 +# b.linger = 0 +# sockets.extend([a, b]) +# a.connect(url) +# b.bind(url) +# b = Socket(b) +# async with b: +# f = b.arecv() +# await sleep(0.1) +# a.send(b"hi") +# assert await f.wait() == b"hi" +# +# +#@pytest.mark.parametrize("total_threads", [1, 2]) +#async def test_start_socket(total_threads, create_bound_pair): +# to_thread.current_default_thread_limiter().total_tokens = total_threads +# +# a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) +# a_started = False +# b_started = False +# +# with pytest.raises(BaseException): +# async with b: +# b_started = True +# with move_on_after(0.1): +# async with a: +# a_started = True +# raise RuntimeError +# +# assert b_started +# assert a_started +# +# to_thread.current_default_thread_limiter().total_tokens = 40 +# +#async def test_recv_multipart(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) +# async with b, a: +# f = b.arecv_multipart() +# await a.asend(b"hi").wait() +# assert await f.wait() == [b"hi"] +# +# +#async def test_recv(create_bound_pair): +# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) +# async with b, a: +# f1 = b.arecv() +# f2 = b.arecv() +# await a.asend_multipart([b"hi", b"there"]).wait() +# assert await f1.wait() == b"hi" +# assert await f2.wait() == b"there" +# +# +#@pytest.mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO") +#async def test_recv_timeout(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# b.rcvtimeo = 100 +# f1 = b.arecv() +# b.rcvtimeo = 1000 +# f2 = b.arecv_multipart() +# with pytest.raises(zmq.Again): +# await f1.wait() +# await a.asend_multipart([b"hi", b"there"]).wait() +# recvd = await f2.wait() +# assert recvd == [b"hi", b"there"] +# +# +#@pytest.mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO") +#async def test_send_timeout(socket): +# s = socket(zmq.PUSH) +# s.sndtimeo = 100 +# with pytest.raises(zmq.Again): +# await s.send(b"not going anywhere") +# +# +#async def test_recv_string(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# f = b.arecv_string() +# msg = "πøøπ" +# await a.asend_string(msg).wait() +# recvd = await f.wait() +# assert recvd == msg +# +# +#async def test_recv_json(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# f = b.arecv_json() +# obj = dict(a=5) +# await a.asend_json(obj).wait() +# recvd = await f.wait() +# assert recvd == obj +# +# +#async def test_recv_json_cancelled(push_pull): +# async with create_task_group() as tg: +# a, b = map(Socket, push_pull) +# async with b, a: +# f = b.arecv_json() +# assert not f.done() +# f.cancel(raise_exception=True) +# # cycle eventloop to allow cancel events to fire +# await sleep(0) +# obj = dict(a=5) +# await a.asend_json(obj).wait() +# with pytest.raises(CancelledError): +# recvd = await f.wait() +# assert f.cancelled() +# assert f.done() +# # give it a chance to incorrectly consume the event +# events = await b.apoll(timeout=5).wait() +# assert events +# await sleep(0) +# # make sure cancelled recv didn't eat up event +# f = b.arecv_json() +# with move_on_after(5): +# recvd = await f.wait() +# assert recvd == obj +# +# +#async def test_recv_pyobj(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# f = b.arecv_pyobj() +# obj = dict(a=5) +# await a.asend_pyobj(obj).wait() +# recvd = await f.wait() +# assert recvd == obj +# +# +#async def test_custom_serialize(create_bound_pair): +# def serialize(msg): +# frames = [] +# frames.extend(msg.get("identities", [])) +# content = json.dumps(msg["content"]).encode("utf8") +# frames.append(content) +# return frames +# +# def deserialize(frames): +# identities = frames[:-1] +# content = json.loads(frames[-1].decode("utf8")) +# return { +# "identities": identities, +# "content": content, +# } +# +# a, b = map(Socket, create_bound_pair(zmq.DEALER, zmq.ROUTER)) +# async with b, a: +# +# msg = { +# "content": { +# "a": 5, +# "b": "bee", +# } +# } +# await a.asend_serialized(msg, serialize).wait() +# recvd = await b.arecv_serialized(deserialize).wait() +# assert recvd["content"] == msg["content"] +# assert recvd["identities"] +# # bounce back, tests identities +# await b.asend_serialized(recvd, serialize).wait() +# r2 = await a.arecv_serialized(deserialize).wait() +# assert r2["content"] == msg["content"] +# assert not r2["identities"] +# +# +#@pytest.mark.skip(reason="FIXME: sometimes raises CancelledError") +#async def test_custom_serialize_error(dealer_router): +# a, b = map(Socket, dealer_router) +# async with b, a: +# await a.asend(b"not json").wait() +# with pytest.raises(TypeError): +# await b.arecv_serialized(json.loads).wait() +# +# +#async def test_recv_dontwait(push_pull): +# push, pull = map(Socket, push_pull) +# async with pull, push: +# f = pull.arecv(zmq.DONTWAIT) +# with pytest.raises(zmq.Again): +# await f.wait() +# await push.asend(b"ping").wait() +# await pull.apoll().wait() # ensure message will be waiting +# msg = await pull.arecv(zmq.DONTWAIT).wait() +# assert msg == b"ping" +# +# +#async def test_recv_cancel(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# f1 = b.arecv() +# f2 = b.arecv_multipart() +# f1.cancel() +# assert f1.done() +# assert not f2.done() +# await a.asend_multipart([b"hi", b"there"]).wait() +# recvd = await f2.wait() +# assert f1.cancelled() +# assert f2.done() +# assert recvd == [b"hi", b"there"] +# +# +#async def test_poll(push_pull): +# a, b = map(Socket, push_pull) +# async with b, a: +# f = b.apoll(timeout=0) +# await sleep(0.1) +# assert f.result() == 0 +# +# f = b.apoll(timeout=1) +# assert not f.done() +# evt = await f.wait() +# +# assert evt == 0 +# +# f = b.apoll(timeout=1000) +# assert not f.done() +# await a.asend_multipart([b"hi", b"there"]).wait() +# evt = await f.wait() +# assert evt == zmq.POLLIN +# recvd = await b.arecv_multipart().wait() +# assert recvd == [b"hi", b"there"] +# +# +#async def test_poll_base_socket(sockets): +# ctx = zmq.Context() +# url = "inproc://test" +# a = Socket(ctx.socket(zmq.PUSH)) +# b = Socket(ctx.socket(zmq.PULL)) +# sockets.extend([a, b]) +# a.bind(url) +# b.connect(url) +# +# poller = Poller() +# poller.register(b, zmq.POLLIN) +# +# async with create_task_group() as tg: +# f = poller.apoll(tg, timeout=1000) +# assert not f.done() +# a.send_multipart([b"hi", b"there"]) +# evt = await f.wait() +# assert evt == [(b, zmq.POLLIN)] +# recvd = b.recv_multipart() +# assert recvd == [b"hi", b"there"] +# +# +#@pytest.mark.skip(reason="FIXME: sometimes raises ZMQError") +#async def test_poll_on_closed_socket(push_pull): +# a, b = push_pull +# b = Socket(b) +# async with create_task_group() as tg: +# async with b: +# f = b.apoll(timeout=1) +# await sleep(0.1) +# +# assert f.done() async def test_close(create_bound_pair): From 0ade30df16c253d2bd2fa485f61e8dab86457764 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Mar 2025 14:58:51 +0100 Subject: [PATCH 04/10] Wrap fileno() method --- src/zmq_anyio/_socket.py | 6 ++++++ tests/test_socket.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index acc8210..45db64c 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -194,6 +194,12 @@ def __init__( self._task_group = task_group self.__stack = None + def fileno(self) -> int: + try: + return super().fileno() + except zmq.error.ZMQError: + return -1 + def get(self, key): result = super().get(key) if key == EVENTS: diff --git a/tests/test_socket.py b/tests/test_socket.py index 67f7b0e..9c322ab 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -341,8 +341,12 @@ async def test_close(create_bound_pair): async with create_task_group() as tg: await tg.start(a.start) await tg.start(b.start) + print(f"{a.fileno()=}") + print(f"{b.fileno()=}") a.close() b.close() + print(f"{a.fileno()=}") + print(f"{b.fileno()=}") await sleep(0.1) From 76582a59c648bbd2b728fdb699448f9a9cb630d4 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Mar 2025 15:10:01 +0100 Subject: [PATCH 05/10] Don't use shadow socket in wait_readable() --- src/zmq_anyio/_socket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 45db64c..4029608 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -902,7 +902,7 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): exception_handler=ignore_exceptions, ) wait_readable_task = create_task( - wait_readable(self._shadow_sock), # type: ignore[arg-type] + wait_readable(self), # type: ignore[arg-type] self._task_group, exception_handler=ignore_exceptions, ) From 6cdb4016cb8e5160f58d3e9b0be8255a64558613 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Fri, 14 Mar 2025 16:06:35 +0100 Subject: [PATCH 06/10] - --- src/zmq_anyio/_socket.py | 14 +- tests/test_socket.py | 646 +++++++++++++++++++-------------------- 2 files changed, 330 insertions(+), 330 deletions(-) diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 4029608..e861ee0 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -901,19 +901,19 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._task_group, exception_handler=ignore_exceptions, ) - wait_readable_task = create_task( - wait_readable(self), # type: ignore[arg-type] - self._task_group, - exception_handler=ignore_exceptions, - ) tasks = [ - wait_readable_task, + create_task( + wait_readable(self), # type: ignore[arg-type] + self._task_group, + exception_handler=ignore_exceptions, + ), wait_stopped_task, ] done, pending = await wait( tasks, self._task_group, return_when=FIRST_COMPLETED ) - wait_stopped_task.cancel() + for task in pending: + task.cancel() if wait_stopped_task in done: break await self._handle_events() diff --git a/tests/test_socket.py b/tests/test_socket.py index 9c322ab..d89bc73 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -10,329 +10,329 @@ pytestmark = pytest.mark.anyio -#async def test_context(context): -# a, b = Socket(context, zmq.PAIR), Socket(context, zmq.PAIR) -# port = a.bind_to_random_port("tcp://127.0.0.1") -# b.connect(f'tcp://127.0.0.1:{port}') -# a.send(b"Hello") -# assert b.recv() == b"Hello" -# async with a, b: -# await a.asend(b"Hello").wait() -# assert await b.arecv().wait() == b"Hello" -# -# -#async def test_arecv_multipart(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) -# async with b, a: -# f = b.arecv_multipart() -# assert not f.done() -# await a.asend(b"hi").wait() -# recvd = await f.wait() -# assert recvd == [b"hi"] -# -# -#async def test_arecv(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) -# async with b, a: -# f1 = b.arecv() -# f2 = b.arecv() -# assert not f1.done() -# assert not f2.done() -# await a.asend_multipart([b"hi", b"there"]).wait() -# recvd = await f2.wait() -# assert f1.done() -# assert f1.result() == b"hi" -# assert recvd == b"there" -# -# -#async def test_arecv_json(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) -# async with b, a: -# f = b.arecv_json() -# assert not f.done() -# obj = dict(a=5) -# await a.asend_json(obj).wait() -# recvd = await f.wait() -# assert f.done() -# assert f.result() == obj -# assert recvd == obj -# -# -#async def test_arecv_send(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) -# async with b, a: -# f = b.arecv() -# -# def callback(future: Future) -> None: -# b.send(b", World!") -# -# f.add_done_callback(callback) -# a.send(b"Hello") -# assert await a.arecv().wait() == b", World!" -# -# -#async def test_inproc(sockets): -# ctx = zmq.Context() -# url = "inproc://test" -# a = ctx.socket(zmq.PUSH) -# b = ctx.socket(zmq.PULL) -# a.linger = 0 -# b.linger = 0 -# sockets.extend([a, b]) -# a.connect(url) -# b.bind(url) -# b = Socket(b) -# async with b: -# f = b.arecv() -# await sleep(0.1) -# a.send(b"hi") -# assert await f.wait() == b"hi" -# -# -#@pytest.mark.parametrize("total_threads", [1, 2]) -#async def test_start_socket(total_threads, create_bound_pair): -# to_thread.current_default_thread_limiter().total_tokens = total_threads -# -# a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) -# a_started = False -# b_started = False -# -# with pytest.raises(BaseException): -# async with b: -# b_started = True -# with move_on_after(0.1): -# async with a: -# a_started = True -# raise RuntimeError -# -# assert b_started -# assert a_started -# -# to_thread.current_default_thread_limiter().total_tokens = 40 -# -#async def test_recv_multipart(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) -# async with b, a: -# f = b.arecv_multipart() -# await a.asend(b"hi").wait() -# assert await f.wait() == [b"hi"] -# -# -#async def test_recv(create_bound_pair): -# a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) -# async with b, a: -# f1 = b.arecv() -# f2 = b.arecv() -# await a.asend_multipart([b"hi", b"there"]).wait() -# assert await f1.wait() == b"hi" -# assert await f2.wait() == b"there" -# -# -#@pytest.mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO") -#async def test_recv_timeout(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# b.rcvtimeo = 100 -# f1 = b.arecv() -# b.rcvtimeo = 1000 -# f2 = b.arecv_multipart() -# with pytest.raises(zmq.Again): -# await f1.wait() -# await a.asend_multipart([b"hi", b"there"]).wait() -# recvd = await f2.wait() -# assert recvd == [b"hi", b"there"] -# -# -#@pytest.mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO") -#async def test_send_timeout(socket): -# s = socket(zmq.PUSH) -# s.sndtimeo = 100 -# with pytest.raises(zmq.Again): -# await s.send(b"not going anywhere") -# -# -#async def test_recv_string(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# f = b.arecv_string() -# msg = "πøøπ" -# await a.asend_string(msg).wait() -# recvd = await f.wait() -# assert recvd == msg -# -# -#async def test_recv_json(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# f = b.arecv_json() -# obj = dict(a=5) -# await a.asend_json(obj).wait() -# recvd = await f.wait() -# assert recvd == obj -# -# -#async def test_recv_json_cancelled(push_pull): -# async with create_task_group() as tg: -# a, b = map(Socket, push_pull) -# async with b, a: -# f = b.arecv_json() -# assert not f.done() -# f.cancel(raise_exception=True) -# # cycle eventloop to allow cancel events to fire -# await sleep(0) -# obj = dict(a=5) -# await a.asend_json(obj).wait() -# with pytest.raises(CancelledError): -# recvd = await f.wait() -# assert f.cancelled() -# assert f.done() -# # give it a chance to incorrectly consume the event -# events = await b.apoll(timeout=5).wait() -# assert events -# await sleep(0) -# # make sure cancelled recv didn't eat up event -# f = b.arecv_json() -# with move_on_after(5): -# recvd = await f.wait() -# assert recvd == obj -# -# -#async def test_recv_pyobj(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# f = b.arecv_pyobj() -# obj = dict(a=5) -# await a.asend_pyobj(obj).wait() -# recvd = await f.wait() -# assert recvd == obj -# -# -#async def test_custom_serialize(create_bound_pair): -# def serialize(msg): -# frames = [] -# frames.extend(msg.get("identities", [])) -# content = json.dumps(msg["content"]).encode("utf8") -# frames.append(content) -# return frames -# -# def deserialize(frames): -# identities = frames[:-1] -# content = json.loads(frames[-1].decode("utf8")) -# return { -# "identities": identities, -# "content": content, -# } -# -# a, b = map(Socket, create_bound_pair(zmq.DEALER, zmq.ROUTER)) -# async with b, a: -# -# msg = { -# "content": { -# "a": 5, -# "b": "bee", -# } -# } -# await a.asend_serialized(msg, serialize).wait() -# recvd = await b.arecv_serialized(deserialize).wait() -# assert recvd["content"] == msg["content"] -# assert recvd["identities"] -# # bounce back, tests identities -# await b.asend_serialized(recvd, serialize).wait() -# r2 = await a.arecv_serialized(deserialize).wait() -# assert r2["content"] == msg["content"] -# assert not r2["identities"] -# -# -#@pytest.mark.skip(reason="FIXME: sometimes raises CancelledError") -#async def test_custom_serialize_error(dealer_router): -# a, b = map(Socket, dealer_router) -# async with b, a: -# await a.asend(b"not json").wait() -# with pytest.raises(TypeError): -# await b.arecv_serialized(json.loads).wait() -# -# -#async def test_recv_dontwait(push_pull): -# push, pull = map(Socket, push_pull) -# async with pull, push: -# f = pull.arecv(zmq.DONTWAIT) -# with pytest.raises(zmq.Again): -# await f.wait() -# await push.asend(b"ping").wait() -# await pull.apoll().wait() # ensure message will be waiting -# msg = await pull.arecv(zmq.DONTWAIT).wait() -# assert msg == b"ping" -# -# -#async def test_recv_cancel(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# f1 = b.arecv() -# f2 = b.arecv_multipart() -# f1.cancel() -# assert f1.done() -# assert not f2.done() -# await a.asend_multipart([b"hi", b"there"]).wait() -# recvd = await f2.wait() -# assert f1.cancelled() -# assert f2.done() -# assert recvd == [b"hi", b"there"] -# -# -#async def test_poll(push_pull): -# a, b = map(Socket, push_pull) -# async with b, a: -# f = b.apoll(timeout=0) -# await sleep(0.1) -# assert f.result() == 0 -# -# f = b.apoll(timeout=1) -# assert not f.done() -# evt = await f.wait() -# -# assert evt == 0 -# -# f = b.apoll(timeout=1000) -# assert not f.done() -# await a.asend_multipart([b"hi", b"there"]).wait() -# evt = await f.wait() -# assert evt == zmq.POLLIN -# recvd = await b.arecv_multipart().wait() -# assert recvd == [b"hi", b"there"] -# -# -#async def test_poll_base_socket(sockets): -# ctx = zmq.Context() -# url = "inproc://test" -# a = Socket(ctx.socket(zmq.PUSH)) -# b = Socket(ctx.socket(zmq.PULL)) -# sockets.extend([a, b]) -# a.bind(url) -# b.connect(url) -# -# poller = Poller() -# poller.register(b, zmq.POLLIN) -# -# async with create_task_group() as tg: -# f = poller.apoll(tg, timeout=1000) -# assert not f.done() -# a.send_multipart([b"hi", b"there"]) -# evt = await f.wait() -# assert evt == [(b, zmq.POLLIN)] -# recvd = b.recv_multipart() -# assert recvd == [b"hi", b"there"] -# -# -#@pytest.mark.skip(reason="FIXME: sometimes raises ZMQError") -#async def test_poll_on_closed_socket(push_pull): -# a, b = push_pull -# b = Socket(b) -# async with create_task_group() as tg: -# async with b: -# f = b.apoll(timeout=1) -# await sleep(0.1) -# -# assert f.done() +async def test_context(context): + a, b = Socket(context, zmq.PAIR), Socket(context, zmq.PAIR) + port = a.bind_to_random_port("tcp://127.0.0.1") + b.connect(f'tcp://127.0.0.1:{port}') + a.send(b"Hello") + assert b.recv() == b"Hello" + async with a, b: + await a.asend(b"Hello").wait() + assert await b.arecv().wait() == b"Hello" + + +async def test_arecv_multipart(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + async with b, a: + f = b.arecv_multipart() + assert not f.done() + await a.asend(b"hi").wait() + recvd = await f.wait() + assert recvd == [b"hi"] + + +async def test_arecv(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + async with b, a: + f1 = b.arecv() + f2 = b.arecv() + assert not f1.done() + assert not f2.done() + await a.asend_multipart([b"hi", b"there"]).wait() + recvd = await f2.wait() + assert f1.done() + assert f1.result() == b"hi" + assert recvd == b"there" + + +async def test_arecv_json(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + async with b, a: + f = b.arecv_json() + assert not f.done() + obj = dict(a=5) + await a.asend_json(obj).wait() + recvd = await f.wait() + assert f.done() + assert f.result() == obj + assert recvd == obj + + +async def test_arecv_send(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) + async with b, a: + f = b.arecv() + + def callback(future: Future) -> None: + b.send(b", World!") + + f.add_done_callback(callback) + a.send(b"Hello") + assert await a.arecv().wait() == b", World!" + + +async def test_inproc(sockets): + ctx = zmq.Context() + url = "inproc://test" + a = ctx.socket(zmq.PUSH) + b = ctx.socket(zmq.PULL) + a.linger = 0 + b.linger = 0 + sockets.extend([a, b]) + a.connect(url) + b.bind(url) + b = Socket(b) + async with b: + f = b.arecv() + await sleep(0.1) + a.send(b"hi") + assert await f.wait() == b"hi" + + +@pytest.mark.parametrize("total_threads", [1, 2]) +async def test_start_socket(total_threads, create_bound_pair): + to_thread.current_default_thread_limiter().total_tokens = total_threads + + a, b = map(Socket, create_bound_pair(zmq.REQ, zmq.REP)) + a_started = False + b_started = False + + with pytest.raises(BaseException): + async with b: + b_started = True + with move_on_after(0.1): + async with a: + a_started = True + raise RuntimeError + + assert b_started + assert a_started + + to_thread.current_default_thread_limiter().total_tokens = 40 + +async def test_recv_multipart(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + async with b, a: + f = b.arecv_multipart() + await a.asend(b"hi").wait() + assert await f.wait() == [b"hi"] + + +async def test_recv(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + async with b, a: + f1 = b.arecv() + f2 = b.arecv() + await a.asend_multipart([b"hi", b"there"]).wait() + assert await f1.wait() == b"hi" + assert await f2.wait() == b"there" + + +@pytest.mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO") +async def test_recv_timeout(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + b.rcvtimeo = 100 + f1 = b.arecv() + b.rcvtimeo = 1000 + f2 = b.arecv_multipart() + with pytest.raises(zmq.Again): + await f1.wait() + await a.asend_multipart([b"hi", b"there"]).wait() + recvd = await f2.wait() + assert recvd == [b"hi", b"there"] + + +@pytest.mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO") +async def test_send_timeout(socket): + s = socket(zmq.PUSH) + s.sndtimeo = 100 + with pytest.raises(zmq.Again): + await s.send(b"not going anywhere") + + +async def test_recv_string(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + f = b.arecv_string() + msg = "πøøπ" + await a.asend_string(msg).wait() + recvd = await f.wait() + assert recvd == msg + + +async def test_recv_json(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + f = b.arecv_json() + obj = dict(a=5) + await a.asend_json(obj).wait() + recvd = await f.wait() + assert recvd == obj + + +async def test_recv_json_cancelled(push_pull): + async with create_task_group() as tg: + a, b = map(Socket, push_pull) + async with b, a: + f = b.arecv_json() + assert not f.done() + f.cancel(raise_exception=True) + # cycle eventloop to allow cancel events to fire + await sleep(0) + obj = dict(a=5) + await a.asend_json(obj).wait() + with pytest.raises(CancelledError): + recvd = await f.wait() + assert f.cancelled() + assert f.done() + # give it a chance to incorrectly consume the event + events = await b.apoll(timeout=5).wait() + assert events + await sleep(0) + # make sure cancelled recv didn't eat up event + f = b.arecv_json() + with move_on_after(5): + recvd = await f.wait() + assert recvd == obj + + +async def test_recv_pyobj(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + f = b.arecv_pyobj() + obj = dict(a=5) + await a.asend_pyobj(obj).wait() + recvd = await f.wait() + assert recvd == obj + + +async def test_custom_serialize(create_bound_pair): + def serialize(msg): + frames = [] + frames.extend(msg.get("identities", [])) + content = json.dumps(msg["content"]).encode("utf8") + frames.append(content) + return frames + + def deserialize(frames): + identities = frames[:-1] + content = json.loads(frames[-1].decode("utf8")) + return { + "identities": identities, + "content": content, + } + + a, b = map(Socket, create_bound_pair(zmq.DEALER, zmq.ROUTER)) + async with b, a: + + msg = { + "content": { + "a": 5, + "b": "bee", + } + } + await a.asend_serialized(msg, serialize).wait() + recvd = await b.arecv_serialized(deserialize).wait() + assert recvd["content"] == msg["content"] + assert recvd["identities"] + # bounce back, tests identities + await b.asend_serialized(recvd, serialize).wait() + r2 = await a.arecv_serialized(deserialize).wait() + assert r2["content"] == msg["content"] + assert not r2["identities"] + + +@pytest.mark.skip(reason="FIXME: sometimes raises CancelledError") +async def test_custom_serialize_error(dealer_router): + a, b = map(Socket, dealer_router) + async with b, a: + await a.asend(b"not json").wait() + with pytest.raises(TypeError): + await b.arecv_serialized(json.loads).wait() + + +async def test_recv_dontwait(push_pull): + push, pull = map(Socket, push_pull) + async with pull, push: + f = pull.arecv(zmq.DONTWAIT) + with pytest.raises(zmq.Again): + await f.wait() + await push.asend(b"ping").wait() + await pull.apoll().wait() # ensure message will be waiting + msg = await pull.arecv(zmq.DONTWAIT).wait() + assert msg == b"ping" + + +async def test_recv_cancel(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + f1 = b.arecv() + f2 = b.arecv_multipart() + f1.cancel() + assert f1.done() + assert not f2.done() + await a.asend_multipart([b"hi", b"there"]).wait() + recvd = await f2.wait() + assert f1.cancelled() + assert f2.done() + assert recvd == [b"hi", b"there"] + + +async def test_poll(push_pull): + a, b = map(Socket, push_pull) + async with b, a: + f = b.apoll(timeout=0) + await sleep(0.1) + assert f.result() == 0 + + f = b.apoll(timeout=1) + assert not f.done() + evt = await f.wait() + + assert evt == 0 + + f = b.apoll(timeout=1000) + assert not f.done() + await a.asend_multipart([b"hi", b"there"]).wait() + evt = await f.wait() + assert evt == zmq.POLLIN + recvd = await b.arecv_multipart().wait() + assert recvd == [b"hi", b"there"] + + +async def test_poll_base_socket(sockets): + ctx = zmq.Context() + url = "inproc://test" + a = Socket(ctx.socket(zmq.PUSH)) + b = Socket(ctx.socket(zmq.PULL)) + sockets.extend([a, b]) + a.bind(url) + b.connect(url) + + poller = Poller() + poller.register(b, zmq.POLLIN) + + async with create_task_group() as tg: + f = poller.apoll(tg, timeout=1000) + assert not f.done() + a.send_multipart([b"hi", b"there"]) + evt = await f.wait() + assert evt == [(b, zmq.POLLIN)] + recvd = b.recv_multipart() + assert recvd == [b"hi", b"there"] + + +@pytest.mark.skip(reason="FIXME: sometimes raises ZMQError") +async def test_poll_on_closed_socket(push_pull): + a, b = push_pull + b = Socket(b) + async with create_task_group() as tg: + async with b: + f = b.apoll(timeout=1) + await sleep(0.1) + + assert f.done() async def test_close(create_bound_pair): From 55024088ad283631e9f5aca0d3d3e0299b29387e Mon Sep 17 00:00:00 2001 From: David Brochart Date: Sun, 16 Mar 2025 17:12:28 +0100 Subject: [PATCH 07/10] - --- .github/workflows/test.yml | 4 +- pyproject.toml | 1 + src/zmq_anyio/_selector_thread.py | 387 ++++++++++++++++++++++++++++++ src/zmq_anyio/_socket.py | 3 + tests/test_socket.py | 21 +- 5 files changed, 411 insertions(+), 5 deletions(-) create mode 100644 src/zmq_anyio/_selector_thread.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e47e860..b4f97e5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,9 +33,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: | - pip install -e ".[test]" - pip install git+https://github.com/davidbrochart/anyio.git@show-error#egg=anyio --ignore-installed + run: pip install -e ".[test]" - name: Check with mypy and ruff if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | diff --git a/pyproject.toml b/pyproject.toml index b76bcca..019fc3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ requires-python = ">= 3.9" dependencies = [ "anyio >=4.8.0,<5.0.0", "anyioutils >=0.7.1,<0.8.0", + "sniffio", "pyzmq >=26.0.0,<27.0.0", ] diff --git a/src/zmq_anyio/_selector_thread.py b/src/zmq_anyio/_selector_thread.py new file mode 100644 index 0000000..68bca92 --- /dev/null +++ b/src/zmq_anyio/_selector_thread.py @@ -0,0 +1,387 @@ +"""Ensure asyncio selector methods (add_reader, etc.) are available. +Running select in a thread and defining these methods on the running event loop. +Originally in tornado.platform.asyncio. +Redistributed under license Apache-2.0 +""" + +from __future__ import annotations + +import asyncio +import atexit +import errno +import functools +import select +import socket +import sys +import threading +import typing +from typing import ( + Any, + Callable, + Union, +) +from weakref import WeakKeyDictionary + +from sniffio import current_async_library + +if typing.TYPE_CHECKING: + from typing_extensions import Protocol + + class _HasFileno(Protocol): + def fileno(self) -> int: + pass + + _FileDescriptorLike = Union[int, _HasFileno] + + +# Collection of selector thread event loops to shut down on exit. +_selector_loops: set[SelectorThread] = set() + + +def _atexit_callback() -> None: + for loop in _selector_loops: + with loop._select_cond: + loop._closing_selector = True + loop._select_cond.notify() + try: + loop._waker_w.send(b"a") + except BlockingIOError: + pass + # If we don't join our (daemon) thread here, we may get a deadlock + # during interpreter shutdown. I don't really understand why. This + # deadlock happens every time in CI (both travis and appveyor) but + # I've never been able to reproduce locally. + assert loop._thread is not None + loop._thread.join() + _selector_loops.clear() + + +atexit.register(_atexit_callback) + + +# SelectorThread from tornado 6.4.0 + + +class SelectorThread: + """Define ``add_reader`` methods to be called in a background select thread. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; + all callbacks are run on the wrapped event loop's thread. + + Typically used via ``AddThreadSelectorEventLoop``, + but can be attached to a running asyncio loop. + """ + + _closed = False + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + + self._select_cond = threading.Condition() + self._select_args: ( + tuple[list[_FileDescriptorLike], list[_FileDescriptorLike]] | None + ) = None + self._closing_selector = False + self._thread: threading.Thread | None = None + self._thread_manager_handle = self._thread_manager() + + async def thread_manager_anext() -> None: + # the anext builtin wasn't added until 3.10. We just need to iterate + # this generator one step. + await self._thread_manager_handle.__anext__() + + # When the loop starts, start the thread. Not too soon because we can't + # clean up if we get to this point but the event loop is closed without + # starting. + self._real_loop.call_soon( + lambda: self._real_loop.create_task(thread_manager_anext()) + ) + + self._readers: dict[_FileDescriptorLike, Callable] = {} + self._writers: dict[_FileDescriptorLike, Callable] = {} + + # Writing to _waker_w will wake up the selector thread, which + # watches for _waker_r to be readable. + self._waker_r, self._waker_w = socket.socketpair() + self._waker_r.setblocking(False) + self._waker_w.setblocking(False) + _selector_loops.add(self) + self.add_reader(self._waker_r, self._consume_waker) + + def close(self) -> None: + if self._closed: + return + with self._select_cond: + self._closing_selector = True + self._select_cond.notify() + self._wake_selector() + if self._thread is not None: + self._thread.join() + _selector_loops.discard(self) + self.remove_reader(self._waker_r) + self._waker_r.close() + self._waker_w.close() + self._closed = True + + async def _thread_manager(self) -> typing.AsyncGenerator[None, None]: + # Create a thread to run the select system call. We manage this thread + # manually so we can trigger a clean shutdown from an atexit hook. Note + # that due to the order of operations at shutdown, only daemon threads + # can be shut down in this way (non-daemon threads would require the + # introduction of a new hook: https://bugs.python.org/issue41962) + self._thread = threading.Thread( + name="Tornado selector", + daemon=True, + target=self._run_select, + ) + self._thread.start() + self._start_select() + try: + # The presense of this yield statement means that this coroutine + # is actually an asynchronous generator, which has a special + # shutdown protocol. We wait at this yield point until the + # event loop's shutdown_asyncgens method is called, at which point + # we will get a GeneratorExit exception and can shut down the + # selector thread. + yield + except GeneratorExit: + self.close() + raise + + def _wake_selector(self) -> None: + if self._closed: + return + try: + self._waker_w.send(b"a") + except BlockingIOError: + pass + + def _consume_waker(self) -> None: + try: + self._waker_r.recv(1024) + except BlockingIOError: + pass + + def _start_select(self) -> None: + # Capture reader and writer sets here in the event loop + # thread to avoid any problems with concurrent + # modification while the select loop uses them. + with self._select_cond: + assert self._select_args is None + self._select_args = (list(self._readers.keys()), list(self._writers.keys())) + self._select_cond.notify() + + def _run_select(self) -> None: + while True: + with self._select_cond: + while self._select_args is None and not self._closing_selector: + self._select_cond.wait() + if self._closing_selector: + return + assert self._select_args is not None + to_read, to_write = self._select_args + self._select_args = None + + # We use the simpler interface of the select module instead of + # the more stateful interface in the selectors module because + # this class is only intended for use on windows, where + # select.select is the only option. The selector interface + # does not have well-documented thread-safety semantics that + # we can rely on so ensuring proper synchronization would be + # tricky. + try: + # On windows, selecting on a socket for write will not + # return the socket when there is an error (but selecting + # for reads works). Also select for errors when selecting + # for writes, and merge the results. + # + # This pattern is also used in + # https://github.com/python/cpython/blob/v3.8.0/Lib/selectors.py#L312-L317 + rs, ws, xs = select.select(to_read, to_write, to_write) + ws = ws + xs + except OSError as e: + # After remove_reader or remove_writer is called, the file + # descriptor may subsequently be closed on the event loop + # thread. It's possible that this select thread hasn't + # gotten into the select system call by the time that + # happens in which case (at least on macOS), select may + # raise a "bad file descriptor" error. If we get that + # error, check and see if we're also being woken up by + # polling the waker alone. If we are, just return to the + # event loop and we'll get the updated set of file + # descriptors on the next iteration. Otherwise, raise the + # original error. + if e.errno == getattr(errno, "WSAENOTSOCK", errno.EBADF): + rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) + if rs: + ws = [] + else: + raise + else: + raise + + try: + self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws) + except RuntimeError: + # "Event loop is closed". Swallow the exception for + # consistency with PollIOLoop (and logical consistency + # with the fact that we can't guarantee that an + # add_callback that completes without error will + # eventually execute). + pass + except AttributeError: + # ProactorEventLoop may raise this instead of RuntimeError + # if call_soon_threadsafe races with a call to close(). + # Swallow it too for consistency. + pass + + def _handle_select( + self, rs: list[_FileDescriptorLike], ws: list[_FileDescriptorLike] + ) -> None: + for r in rs: + self._handle_event(r, self._readers) + for w in ws: + self._handle_event(w, self._writers) + self._start_select() + + def _handle_event( + self, + fd: _FileDescriptorLike, + cb_map: dict[_FileDescriptorLike, Callable], + ) -> None: + try: + callback = cb_map[fd] + except KeyError: + return + callback() + + def add_reader( + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + self._readers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def add_writer( + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + self._writers[fd] = functools.partial(callback, *args) + self._wake_selector() + + def remove_reader(self, fd: _FileDescriptorLike) -> bool: + try: + del self._readers[fd] + except KeyError: + return False + self._wake_selector() + return True + + def remove_writer(self, fd: _FileDescriptorLike) -> bool: + try: + del self._writers[fd] + except KeyError: + return False + self._wake_selector() + return True + + +# AddThreadSelectorEventLoop: unmodified from tornado 6.4.0 +class AddThreadSelectorEventLoop(asyncio.AbstractEventLoop): + """Wrap an event loop to add implementations of the ``add_reader`` method family. + + Instances of this class start a second thread to run a selector. + This thread is completely hidden from the user; all callbacks are + run on the wrapped event loop's thread. + + This class is used automatically by Tornado; applications should not need + to refer to it directly. + + It is safe to wrap any event loop with this class, although it only makes sense + for event loops that do not implement the ``add_reader`` family of methods + themselves (i.e. ``WindowsProactorEventLoop``) + + Closing the ``AddThreadSelectorEventLoop`` also closes the wrapped event loop. + """ + + # This class is a __getattribute__-based proxy. All attributes other than those + # in this set are proxied through to the underlying loop. + MY_ATTRIBUTES = { + "_real_loop", + "_selector", + "add_reader", + "add_writer", + "close", + "remove_reader", + "remove_writer", + } + + def __getattribute__(self, name: str) -> Any: + if name in AddThreadSelectorEventLoop.MY_ATTRIBUTES: + return super().__getattribute__(name) + return getattr(self._real_loop, name) + + def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: + self._real_loop = real_loop + self._selector = SelectorThread(real_loop) + + def close(self) -> None: + self._selector.close() + self._real_loop.close() + + def add_reader( # type: ignore[override] + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + return self._selector.add_reader(fd, callback, *args) + + def add_writer( # type: ignore[override] + self, fd: _FileDescriptorLike, callback: Callable[..., None], *args: Any + ) -> None: + return self._selector.add_writer(fd, callback, *args) + + def remove_reader(self, fd: _FileDescriptorLike) -> bool: + return self._selector.remove_reader(fd) + + def remove_writer(self, fd: _FileDescriptorLike) -> bool: + return self._selector.remove_writer(fd) + + +# registry of asyncio loop : selector thread +_selectors: WeakKeyDictionary = WeakKeyDictionary() + + +def _set_selector_windows() -> None: + """Set selector-compatible loop. + Sets ``add_reader`` family of methods on the asyncio loop. + Workaround Windows proactor removal of *reader methods. + """ + if not ( + sys.platform == "win32" + and current_async_library() == "asyncio" + and asyncio.get_event_loop_policy().__class__.__name__ + == "WindowsProactorEventLoopPolicy" + ): + return + + asyncio_loop = asyncio.get_running_loop() + if asyncio_loop in _selectors: + return + + from ._selector_thread import AddThreadSelectorEventLoop + + selector_loop = _selectors[asyncio_loop] = AddThreadSelectorEventLoop( # type: ignore[abstract] + asyncio_loop + ) + + # patch loop.close to also close the selector thread + loop_close = asyncio_loop.close + + def _close_selector_and_loop() -> None: + # restore original before calling selector.close, + # which in turn calls eventloop.close! + asyncio_loop.close = loop_close # type: ignore[method-assign] + _selectors.pop(asyncio_loop, None) + selector_loop.close() + + asyncio_loop.close = _close_selector_and_loop # type: ignore[method-assign] + asyncio_loop.add_reader = selector_loop.add_reader # type: ignore[assignment] + asyncio_loop.remove_reader = selector_loop.remove_reader # type: ignore[method-assign] diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index e861ee0..e69186b 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -27,6 +27,8 @@ from zmq import EVENTS, POLLIN, POLLOUT from zmq.utils import jsonapi +from ._selector_thread import _set_selector_windows + try: DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL except AttributeError: @@ -887,6 +889,7 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): if self.started.is_set(): return + _set_selector_windows() assert self.started is not None assert self.stopped is not None assert self._exited is not None diff --git a/tests/test_socket.py b/tests/test_socket.py index d89bc73..76cc73f 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -10,6 +10,22 @@ pytestmark = pytest.mark.anyio +async def test_close1(create_bound_pair): + a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) + with fail_after(1): + async with create_task_group() as tg: + await tg.start(a.start) + await tg.start(b.start) + await sleep(0.4) + print(f"{a.fileno()=}") + print(f"{b.fileno()=}") + a.close() + b.close() + print(f"{a.fileno()=}") + print(f"{b.fileno()=}") + await sleep(0.4) + + async def test_context(context): a, b = Socket(context, zmq.PAIR), Socket(context, zmq.PAIR) port = a.bind_to_random_port("tcp://127.0.0.1") @@ -335,19 +351,20 @@ async def test_poll_on_closed_socket(push_pull): assert f.done() -async def test_close(create_bound_pair): +async def test_close2(create_bound_pair): a, b = map(Socket, create_bound_pair(zmq.PUSH, zmq.PULL)) with fail_after(1): async with create_task_group() as tg: await tg.start(a.start) await tg.start(b.start) + await sleep(0.4) print(f"{a.fileno()=}") print(f"{b.fileno()=}") a.close() b.close() print(f"{a.fileno()=}") print(f"{b.fileno()=}") - await sleep(0.1) + await sleep(0.4) async def test_wait_readable(): From 2499a4eb012c6286130b570ffa7018be57a21fcf Mon Sep 17 00:00:00 2001 From: David Brochart Date: Sun, 16 Mar 2025 17:50:43 +0100 Subject: [PATCH 08/10] Never raise in the selector thread --- src/zmq_anyio/_selector_thread.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/zmq_anyio/_selector_thread.py b/src/zmq_anyio/_selector_thread.py index 68bca92..c64c85b 100644 --- a/src/zmq_anyio/_selector_thread.py +++ b/src/zmq_anyio/_selector_thread.py @@ -216,10 +216,10 @@ def _run_select(self) -> None: rs, _, _ = select.select([self._waker_r.fileno()], [], [], 0) if rs: ws = [] - else: - raise - else: - raise + # else: + # raise + # else: + # raise try: self._real_loop.call_soon_threadsafe(self._handle_select, rs, ws) From 71b5418e70259b3f3efba402b3215b44a5bff151 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 17 Mar 2025 21:11:47 +0100 Subject: [PATCH 09/10] Use notify_closing --- .github/workflows/test.yml | 4 +++- pyproject.toml | 1 - src/zmq_anyio/_socket.py | 3 --- tests/test_socket.py | 6 +++++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b4f97e5..671e2fa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip install -e ".[test]" + run: | + pip install -e ".[test]" + pip install git+https://github.com/agronholm/anyio.git@notify-closing#egg=anyio --ignore-installed - name: Check with mypy and ruff if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | diff --git a/pyproject.toml b/pyproject.toml index 019fc3c..b76bcca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ requires-python = ">= 3.9" dependencies = [ "anyio >=4.8.0,<5.0.0", "anyioutils >=0.7.1,<0.8.0", - "sniffio", "pyzmq >=26.0.0,<27.0.0", ] diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index e69186b..e861ee0 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -27,8 +27,6 @@ from zmq import EVENTS, POLLIN, POLLOUT from zmq.utils import jsonapi -from ._selector_thread import _set_selector_windows - try: DEFAULT_PROTOCOL = pickle.DEFAULT_PROTOCOL except AttributeError: @@ -889,7 +887,6 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): if self.started.is_set(): return - _set_selector_windows() assert self.started is not None assert self.stopped is not None assert self._exited is not None diff --git a/tests/test_socket.py b/tests/test_socket.py index 76cc73f..23810df 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -3,7 +3,7 @@ import pytest import zmq -from anyio import create_task_group, fail_after, move_on_after, sleep, to_thread, wait_all_tasks_blocked, wait_readable +from anyio import create_task_group, fail_after, move_on_after, notify_closing, sleep, to_thread, wait_all_tasks_blocked, wait_readable from anyioutils import CancelledError, Future, create_task from zmq_anyio import Poller, Socket @@ -19,6 +19,8 @@ async def test_close1(create_bound_pair): await sleep(0.4) print(f"{a.fileno()=}") print(f"{b.fileno()=}") + notify_closing(a) + notify_closing(b) a.close() b.close() print(f"{a.fileno()=}") @@ -360,6 +362,8 @@ async def test_close2(create_bound_pair): await sleep(0.4) print(f"{a.fileno()=}") print(f"{b.fileno()=}") + notify_closing(a) + notify_closing(b) a.close() b.close() print(f"{a.fileno()=}") From 9bec674859c6957a57b923f15ec1c1a4cb9546a5 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 17 Mar 2025 23:25:14 +0100 Subject: [PATCH 10/10] Stop on ClosedResourceError --- src/zmq_anyio/_socket.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index e861ee0..6b48f86 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -901,20 +901,21 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): self._task_group, exception_handler=ignore_exceptions, ) + wait_readable_task = create_task( + wait_readable(self), # type: ignore[arg-type] + self._task_group, + exception_handler=self._handle_closed_resource_error, + ) tasks = [ - create_task( - wait_readable(self), # type: ignore[arg-type] - self._task_group, - exception_handler=ignore_exceptions, - ), wait_stopped_task, + wait_readable_task, ] done, pending = await wait( tasks, self._task_group, return_when=FIRST_COMPLETED ) for task in pending: task.cancel() - if wait_stopped_task in done: + if wait_stopped_task in done or self.stopped.is_set(): break await self._handle_events() except BaseException: @@ -922,8 +923,12 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): finally: self._exited.set() + self.stopped.set() + + def _handle_closed_resource_error(self, exc: BaseException) -> bool: assert self.stopped is not None self.stopped.set() + return True async def stop(self): assert self._exited is not None