diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 42a6f1c..6c0ce29 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,8 +25,6 @@ jobs: check_lint: ['0'] extra_name: [''] include: - - python: '2.7' - extra_name: ', build only' - python: '3.9' check_lint: '1' extra_name: ', check lint' diff --git a/CHANGES.txt b/CHANGES.txt index 96c7638..f643721 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,3 +1,10 @@ +v1.0.0, unreleased + Propagate exceptions raised by the user's packet callback + Warn about exceptions raised by the packet callback during queue unbinding + Raise an error if a packet verdict is set after its parent queue is closed + set_payload() now affects the result of later get_payload() + Handle signals received when run() is blocked in recv() + v0.9.0, 12 Jan 2021 Improve usability when Packet objects are retained past the callback Add Packet.retain() to save the packet contents in such cases diff --git a/README.rst b/README.rst index 8cb73d0..e9b0231 100644 --- a/README.rst +++ b/README.rst @@ -90,13 +90,9 @@ To install from source:: API === -``NetfilterQueue.COPY_NONE`` - -``NetfilterQueue.COPY_META`` - -``NetfilterQueue.COPY_PACKET`` +``NetfilterQueue.COPY_NONE``, ``NetfilterQueue.COPY_META``, ``NetfilterQueue.COPY_PACKET`` These constants specify how much of the packet should be given to the - script- nothing, metadata, or the whole packet. + script: nothing, metadata, or the whole packet. NetfilterQueue objects ---------------------- @@ -104,7 +100,7 @@ NetfilterQueue objects A NetfilterQueue object represents a single queue. Configure your queue with a call to ``bind``, then start receiving packets with a call to ``run``. -``QueueHandler.bind(queue_num, callback[, max_len[, mode[, range[, sock_len]]]])`` +``NetfilterQueue.bind(queue_num, callback, max_len=1024, mode=COPY_PACKET, range=65535, sock_len=...)`` Create and bind to the queue. ``queue_num`` uniquely identifies this queue for the kernel. It must match the ``--queue-num`` in your iptables rule, but there is no ordering requirement: it's fine to either ``bind()`` @@ -118,22 +114,23 @@ a call to ``bind``, then start receiving packets with a call to ``run``. the source and destination IPs of a IPv4 packet, ``range`` could be 20. ``sock_len`` sets the receive socket buffer size. -``QueueHandler.unbind()`` +``NetfilterQueue.unbind()`` Remove the queue. Packets matched by your iptables rule will be dropped. -``QueueHandler.get_fd()`` +``NetfilterQueue.get_fd()`` Get the file descriptor of the socket used to receive queued packets and send verdicts. If you're using an async event loop, you can poll this FD for readability and call ``run(False)`` every time data appears on it. -``QueueHandler.run([block])`` +``NetfilterQueue.run(block=True)`` Send packets to your callback. By default, this method blocks, running until an exception is raised (such as by Ctrl+C). Set - block=False to process the pending messages without waiting for more. - You can get the file descriptor of the socket with the ``get_fd`` method. + ``block=False`` to process the pending messages without waiting for more; + in conjunction with the ``get_fd`` method, you can use this to integrate + with async event loops. -``QueueHandler.run_socket(socket)`` +``NetfilterQueue.run_socket(socket)`` Send packets to your callback, but use the supplied socket instead of recv, so that, for example, gevent can monkeypatch it. 
You can make a socket with ``socket.fromfd(nfqueue.get_fd(), socket.AF_NETLINK, socket.SOCK_RAW)`` @@ -148,6 +145,8 @@ Objects of this type are passed to your callback. Return the packet's payload as a bytes object. The returned value starts with the IP header. You must call ``retain()`` if you want to be able to ``get_payload()`` after your callback has returned. + If you have already called ``set_payload()``, then ``get_payload()`` + returns what you passed to ``set_payload()``. ``Packet.set_payload(payload)`` Set the packet payload. Call this before ``accept()`` if you want to @@ -166,12 +165,46 @@ Objects of this type are passed to your callback. rules. ``mark`` is a 32-bit number. ``Packet.get_mark()`` - Get the mark already on the packet (either the one you set using + Get the mark on the packet (either the one you set using ``set_mark()``, or the one it arrived with if you haven't called ``set_mark()``). ``Packet.get_hw()`` - Return the hardware address as a Python string. + Return the source hardware address of the packet as a Python + bytestring, or ``None`` if the source hardware address was not + captured (packets captured by the ``OUTPUT`` or ``PREROUTING`` + hooks). For example, on Ethernet the result will be a six-byte + MAC address. The destination hardware address is not available + because it is determined in the kernel only after packet filtering + is complete. + +``Packet.get_timestamp()`` + Return the time at which this packet was received by the kernel, + as a floating-point Unix timestamp with microsecond precision + (comparable to the result of ``time.time()``, for example). + Packets captured by the ``OUTPUT`` or ``POSTROUTING`` hooks + do not have a timestamp, and ``get_timestamp()`` will return 0.0 + for them. + +``Packet.id`` + The identifier assigned to this packet by the kernel. Typically + the first packet received by your queue starts at 1 and later ones + count up from there. + +``Packet.hw_protocol`` + The link-layer protocol for this packet. For example, IPv4 packets + on Ethernet would have this set to the EtherType for IPv4, which is + ``0x0800``. + +``Packet.mark`` + The mark that had been assigned to this packet when it was enqueued. + Unlike the result of ``get_mark()``, this does not change if you call + ``set_mark()``. + +``Packet.hook`` + The netfilter hook (iptables chain, roughly) that diverted this packet + into our queue. Values 0 through 4 correspond to PREROUTING, INPUT, + FORWARD, OUTPUT, and POSTROUTING respectively. ``Packet.retain()`` Allocate a copy of the packet payload for use after the callback @@ -249,20 +282,39 @@ The fields are: Limitations =========== -* Compiled with a 4096-byte buffer for packets, so it probably won't work on - loopback or Ethernet with jumbo packets. If this is a problem, either lower - MTU on your loopback, disable jumbo packets, or get Cython, - change ``DEF BufferSize = 4096`` in ``netfilterqueue.pyx``, and rebuild. -* Full libnetfilter_queue API is not yet implemented: +* We use a fixed-size 4096-byte buffer for packets, so you are likely + to see truncation on loopback and on Ethernet with jumbo packets. + If this is a problem, either lower the MTU on your loopback, disable + jumbo packets, or get Cython, change ``DEF BufferSize = 4096`` in + ``netfilterqueue.pyx``, and rebuild. 
- * Omits methods for getting information about the interface a packet has - arrived on or is leaving on - * Probably other stuff is omitted too +* Not all information available from libnetfilter_queue is exposed: + missing pieces include packet input/output network interface names, + checksum offload flags, UID/GID and security context data + associated with the packet (if any). + +* Not all information available from the kernel is even processed by + libnetfilter_queue: missing pieces include additional link-layer + header data for some packets (including VLAN tags), connection-tracking + state, and incoming packet length (if truncated for queueing). + +* We do not expose the libnetfilter_queue interface for changing queue flags. + Most of these pertain to other features we don't support (listed above), + but there's one that could set the queue to accept (rather than dropping) + packets received when it's full. Source ====== -https://github.com/kti/python-netfilterqueue +https://github.com/oremanj/python-netfilterqueue + +Authorship +========== + +python-netfilterqueue was originally written by Matthew Fox of +Kerkhoff Technologies, Inc. Since 2022 it has been maintained by +Joshua Oreman of Hudson River Trading LLC. Both authors wish to +thank their employers for their support of open source. License ======= diff --git a/ci.sh b/ci.sh index 0284cdc..d836d17 100755 --- a/ci.sh +++ b/ci.sh @@ -13,12 +13,6 @@ python setup.py sdist --formats=zip pip uninstall -y cython pip install dist/*.zip -if python --version 2>&1 | fgrep -q "Python 2.7"; then - # The testsuite doesn't run on 2.7, so do just a basic smoke test. - unshare -Urn python -c "from netfilterqueue import NetfilterQueue as NFQ; NFQ()" - exit $? -fi - pip install -Ur test-requirements.txt if [ "$CHECK_LINT" = "1" ]; then diff --git a/netfilterqueue.pxd b/netfilterqueue.pxd index e82d904..f00bce1 100644 --- a/netfilterqueue.pxd +++ b/netfilterqueue.pxd @@ -8,6 +8,7 @@ cdef extern from "": # dummy defines from asm-generic/errno.h: cdef enum: + EINTR = 4 EAGAIN = 11 # Try again EWOULDBLOCK = EAGAIN ENOBUFS = 105 # No buffer space available @@ -115,15 +116,17 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h": u_int16_t num, nfq_callback *cb, void *data) - int nfq_destroy_queue(nfq_q_handle *qh) - int nfq_handle_packet(nfq_handle *h, char *buf, int len) - - int nfq_set_mode(nfq_q_handle *qh, - u_int8_t mode, unsigned int len) - - q_set_queue_maxlen(nfq_q_handle *qh, - u_int32_t queuelen) + # Any function that parses Netlink replies might invoke the user + # callback and thus might need to propagate a Python exception. + # This includes nfq_handle_packet but is not limited to that -- + # other functions might send a query, read until they get the reply, + # and find a packet notification before the reply which they then + # must deal with. + int nfq_destroy_queue(nfq_q_handle *qh) except? -1 + int nfq_handle_packet(nfq_handle *h, char *buf, int len) except? -1 + int nfq_set_mode(nfq_q_handle *qh, u_int8_t mode, unsigned int len) except? -1 + int nfq_set_queue_maxlen(nfq_q_handle *qh, u_int32_t queuelen) except? 
-1 int nfq_set_verdict(nfq_q_handle *qh, u_int32_t id, @@ -137,7 +140,6 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h": u_int32_t mark, u_int32_t datalen, unsigned char *buf) nogil - int nfq_set_queue_maxlen(nfq_q_handle *qh, u_int32_t queuelen) int nfq_fd(nfq_handle *h) nfqnl_msg_packet_hdr *nfq_get_msg_packet_hdr(nfq_data *nfad) @@ -146,7 +148,7 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h": nfqnl_msg_packet_hw *nfq_get_packet_hw(nfq_data *nfad) int nfq_get_nfmark (nfq_data *nfad) nfnl_handle *nfq_nfnlh(nfq_handle *h) - + # Dummy defines from linux/socket.h: cdef enum: # Protocol families, same as address families. PF_INET = 2 @@ -166,12 +168,19 @@ cdef enum: NF_STOP NF_MAX_VERDICT = NF_STOP +cdef class NetfilterQueue: + cdef object __weakref__ + cdef object user_callback # User callback + cdef nfq_handle *h # Handle to NFQueue library + cdef nfq_q_handle *qh # A handle to the queue + cdef class Packet: - cdef nfq_q_handle *_qh + cdef NetfilterQueue _queue cdef bint _verdict_is_set # True if verdict has been issued, # false otherwise cdef bint _mark_is_set # True if a mark has been given, false otherwise cdef bint _hwaddr_is_set + cdef bint _timestamp_is_set cdef u_int32_t _given_mark # Mark given to packet cdef bytes _given_payload # New payload of packet, or null cdef bytes _owned_payload @@ -189,16 +198,15 @@ cdef class Packet: cdef u_int8_t hw_addr[8] # TODO: implement these - #cdef u_int8_t hw_addr[8] # A eui64-formatted address? #cdef readonly u_int32_t nfmark #cdef readonly u_int32_t indev #cdef readonly u_int32_t physindev #cdef readonly u_int32_t outdev #cdef readonly u_int32_t physoutdev - cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa) + cdef set_nfq_data(self, NetfilterQueue queue, nfq_data *nfa) cdef drop_refs(self) - cdef void verdict(self, u_int8_t verdict) + cdef int verdict(self, u_int8_t verdict) except -1 cpdef Py_ssize_t get_payload_len(self) cpdef double get_timestamp(self) cpdef bytes get_payload(self) @@ -209,11 +217,3 @@ cdef class Packet: cpdef accept(self) cpdef drop(self) cpdef repeat(self) - -cdef class NetfilterQueue: - cdef object user_callback # User callback - cdef nfq_handle *h # Handle to NFQueue library - cdef nfq_q_handle *qh # A handle to the queue - cdef u_int16_t af # Address family - cdef packet_copy_size # Amount of packet metadata + data copied to buffer - diff --git a/netfilterqueue.pyx b/netfilterqueue.pyx index 42a4a98..f917716 100644 --- a/netfilterqueue.pyx +++ b/netfilterqueue.pyx @@ -24,22 +24,26 @@ DEF SockCopySize = MaxCopySize + SockOverhead # Socket queue should hold max number of packets of copysize bytes DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize // 2 -cdef extern from *: - """ - #if PY_MAJOR_VERSION < 3 - #define PyBytes_FromStringAndSize PyString_FromStringAndSize - #endif - """ +from cpython.exc cimport PyErr_CheckSignals +# A negative return value from this callback will stop processing and +# make nfq_handle_packet return -1, so we use that as the error flag. cdef int global_callback(nfq_q_handle *qh, nfgenmsg *nfmsg, - nfq_data *nfa, void *data) with gil: + nfq_data *nfa, void *data) except -1 with gil: """Create a Packet and pass it to appropriate callback.""" cdef NetfilterQueue nfqueue = data cdef object user_callback = nfqueue.user_callback + if user_callback is None: + # Queue is being unbound; we can't send a verdict at this point + # so just ignore the packet. The kernel will drop it once we + # unbind. 
+ return 1 packet = Packet() - packet.set_nfq_data(qh, nfa) - user_callback(packet) - packet.drop_refs() + packet.set_nfq_data(nfqueue, nfa) + try: + user_callback(packet) + finally: + packet.drop_refs() return 1 cdef class Packet: @@ -54,7 +58,7 @@ cdef class Packet: protocol = PROTOCOLS.get(hdr.protocol, "Unknown protocol") return "%s packet, %s bytes" % (protocol, self.payload_len) - cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa): + cdef set_nfq_data(self, NetfilterQueue queue, nfq_data *nfa): """ Assign a packet from NFQ to this object. Parse the header and load local values. @@ -63,7 +67,7 @@ cdef class Packet: cdef nfqnl_msg_packet_hdr *hdr hdr = nfq_get_msg_packet_hdr(nfa) - self._qh = qh + self._queue = queue self.id = ntohl(hdr.packet_id) self.hw_protocol = ntohs(hdr.hw_protocol) self.hook = hdr.hook @@ -78,7 +82,9 @@ cdef class Packet: self.payload_len = nfq_get_payload(nfa, &self.payload) if self.payload_len < 0: - raise OSError("Failed to get payload of packet.") + # Probably using a mode that doesn't provide the payload + self.payload = NULL + self.payload_len = 0 nfq_get_timestamp(nfa, &self.timestamp) self.mark = nfq_get_nfmark(nfa) @@ -86,14 +92,16 @@ cdef class Packet: cdef drop_refs(self): """ Called at the end of the user_callback, when the storage passed to - set_nfq_data() is about to be deallocated. + set_nfq_data() is about to be reused. """ self.payload = NULL - cdef void verdict(self, u_int8_t verdict): + cdef int verdict(self, u_int8_t verdict) except -1: """Call appropriate set_verdict... function on packet.""" if self._verdict_is_set: - raise RuntimeWarning("Verdict already given for this packet.") + raise RuntimeError("Verdict already given for this packet") + if self._queue.qh == NULL: + raise RuntimeError("Parent queue has already been unbound") cdef u_int32_t modified_payload_len = 0 cdef unsigned char *modified_payload = NULL @@ -102,7 +110,7 @@ cdef class Packet: modified_payload = self._given_payload if self._mark_is_set: nfq_set_verdict2( - self._qh, + self._queue.qh, self.id, verdict, self._given_mark, @@ -110,7 +118,7 @@ cdef class Packet: modified_payload) else: nfq_set_verdict( - self._qh, + self._queue.qh, self.id, verdict, modified_payload_len, @@ -119,17 +127,27 @@ cdef class Packet: self._verdict_is_set = True def get_hw(self): - """Return the hardware address as Python string.""" + """Return the packet's source MAC address as a Python bytestring, or + None if it's not available. + """ + if not self._hwaddr_is_set: + return None cdef object py_string py_string = PyBytes_FromStringAndSize(self.hw_addr, 8) return py_string cpdef bytes get_payload(self): """Return payload as Python string.""" - if self._owned_payload: + if self._given_payload: + return self._given_payload + elif self._owned_payload: return self._owned_payload elif self.payload != NULL: return self.payload[:self.payload_len] + elif self.payload_len == 0: + raise RuntimeError( + "Packet has no payload -- perhaps you're using COPY_META mode?" + ) else: raise RuntimeError( "Payload data is no longer available. 
You must call " @@ -172,25 +190,31 @@ cdef class Packet: """Repeat the packet.""" self.verdict(NF_REPEAT) + cdef class NetfilterQueue: """Handle a single numbered queue.""" def __cinit__(self, *args, **kwargs): - self.af = kwargs.get("af", PF_INET) + cdef u_int16_t af # Address family + af = kwargs.get("af", PF_INET) self.h = nfq_open() if self.h == NULL: raise OSError("Failed to open NFQueue.") - nfq_unbind_pf(self.h, self.af) # This does NOT kick out previous - # running queues - if nfq_bind_pf(self.h, self.af) < 0: - raise OSError("Failed to bind family %s. Are you root?" % self.af) + nfq_unbind_pf(self.h, af) # This does NOT kick out previous queues + if nfq_bind_pf(self.h, af) < 0: + raise OSError("Failed to bind family %s. Are you root?" % af) + + def __del__(self): + # unbind() can result in invocations of global_callback, so we + # must do it in __del__ (when this is still a valid + # NetfilterQueue object) rather than __dealloc__ + self.unbind() def __dealloc__(self): - if self.qh != NULL: - nfq_destroy_queue(self.qh) # Don't call nfq_unbind_pf unless you want to disconnect any other # processes using this libnetfilter_queue on this protocol family! - nfq_close(self.h) + if self.h != NULL: + nfq_close(self.h) def bind(self, int queue_num, object user_callback, u_int32_t max_len=DEFAULT_MAX_QUEUELEN, @@ -231,6 +255,7 @@ cdef class NetfilterQueue: def unbind(self): """Destroy the queue.""" + self.user_callback = None if self.qh != NULL: nfq_destroy_queue(self.qh) self.qh = NULL @@ -251,11 +276,17 @@ cdef class NetfilterQueue: while True: with nogil: rv = recv(fd, buf, sizeof(buf), recv_flags) - if (rv >= 0): - nfq_handle_packet(self.h, buf, rv) - else: - if errno != ENOBUFS: + if rv < 0: + if errno == EAGAIN: break + if errno == ENOBUFS: + # Kernel is letting us know we dropped a packet + continue + if errno == EINTR: + PyErr_CheckSignals() + continue + raise OSError(errno, "recv failed") + nfq_handle_packet(self.h, buf, rv) def run_socket(self, s): """Accept packets using socket.recv so that, for example, gevent can monkeypatch it.""" @@ -264,11 +295,6 @@ cdef class NetfilterQueue: while True: try: buf = s.recv(BufferSize) - rv = len(buf) - if rv >= 0: - nfq_handle_packet(self.h, buf, rv) - else: - break except socket.error as e: err = e.args[0] if err == ENOBUFS: @@ -280,6 +306,8 @@ cdef class NetfilterQueue: else: # This is bad. Let the caller handle it. raise e + else: + nfq_handle_packet(self.h, buf, len(buf)) PROTOCOLS = { 0: "HOPOPT", diff --git a/setup.py b/setup.py index ead5b9d..ad7d22d 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ from setuptools import setup, Extension VERSION = "0.9.0" # Remember to change CHANGES.txt and netfilterqueue.pyx when version changes. +setup_requires = [] try: # Use Cython from Cython.Build import cythonize @@ -15,7 +16,13 @@ try: ) except ImportError: # No Cython - if not os.path.exists(os.path.join(os.path.dirname(__file__), "netfilterqueue.c")): + if "egg_info" in sys.argv: + # We're being run by pip to figure out what we need. Request cython in + # setup_requires below. 
+ setup_requires = ["cython"] + elif not os.path.exists( + os.path.join(os.path.dirname(__file__), "netfilterqueue.c") + ): sys.stderr.write( "You must have Cython installed (`pip install cython`) to build this " "package from source.\nIf you're receiving this error when installing from " @@ -29,6 +36,8 @@ except ImportError: setup( ext_modules=ext_modules, + setup_requires=setup_requires, + python_requires=">=3.6", name="NetfilterQueue", version=VERSION, license="MIT", diff --git a/tests/conftest.py b/tests/conftest.py index 1c5478f..0d94e9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,8 @@ import sys import trio import unshare import netfilterqueue -from typing import AsyncIterator +from functools import partial +from typing import AsyncIterator, Callable, Optional, Tuple from async_generator import asynccontextmanager from pytest_trio.enable_trio_mode import * @@ -93,7 +94,7 @@ async def peer_main(idx: int, parent_fd: int) -> None: # Enter the message-forwarding loop async def proxy_one_way(src, dest): - while True: + while src.fileno() >= 0: try: msg = await src.recv(4096) except trio.ClosedResourceError: @@ -111,10 +112,19 @@ async def peer_main(idx: int, parent_fd: int) -> None: nursery.start_soon(proxy_one_way, peer, parent) +def _default_capture_cb( + target: "trio.MemorySendChannel[netfilterqueue.Packet]", + packet: netfilterqueue.Packet, +) -> None: + packet.retain() + target.send_nowait(packet) + + class Harness: def __init__(self): self._received = {} self._conn = {} + self.dest_addr = {} self.failed = False async def _run_peer(self, idx: int, *, task_status): @@ -155,7 +165,9 @@ class Harness: "peer subprocess exited with code {}".format(retval) ) finally: - await trio.run_process(f"ip link delete veth{idx}".split()) + # On some kernels the veth device is removed when the subprocess exits + # and its netns goes away. check=False to suppress that error. 
+ await trio.run_process(f"ip link delete veth{idx}".split(), check=False) async def _manage_peer(self, idx: int, *, task_status): async with trio.open_nursery() as nursery: @@ -177,8 +189,12 @@ class Harness: start_nursery.start_soon(nursery.start, self._manage_peer, 1) start_nursery.start_soon(nursery.start, self._manage_peer, 2) # Tell each peer about the other one's port - await self._conn[2].send(await self._received[1].receive()) - await self._conn[1].send(await self._received[2].receive()) + for idx in (1, 2): + self.dest_addr[idx] = ( + PEER_IP[idx], + int(await self._received[idx].receive()), + ) + await self._conn[3 - idx].send(b"%d" % self.dest_addr[idx][1]) yield self._conn[1].shutdown(socket.SHUT_WR) self._conn[2].shutdown(socket.SHUT_WR) @@ -190,26 +206,22 @@ class Harness: f"Peer {idx} received unexepcted packet {remainder!r}" ) - @asynccontextmanager - async def capture_packets_to( - self, idx: int, *, queue_num: int = -1, **options - ) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]: - - packets_w, packets_r = trio.open_memory_channel(math.inf) - - def stash_packet(p): - p.retain() - packets_w.send_nowait(p) - + def bind_queue( + self, + cb: Callable[[netfilterqueue.Packet], None], + *, + queue_num: int = -1, + **options: int, + ) -> Tuple[int, netfilterqueue.NetfilterQueue]: nfq = netfilterqueue.NetfilterQueue() # Use a smaller socket buffer to avoid a warning in CI options.setdefault("sock_len", 131072) if queue_num >= 0: - nfq.bind(queue_num, stash_packet, **options) + nfq.bind(queue_num, cb, **options) else: for queue_num in range(16): try: - nfq.bind(queue_num, stash_packet, **options) + nfq.bind(queue_num, cb, **options) break except Exception as ex: last_error = ex @@ -217,10 +229,39 @@ class Harness: raise RuntimeError( "Couldn't bind any netfilter queue number between 0-15" ) from last_error + return queue_num, nfq + + @asynccontextmanager + async def enqueue_packets_to( + self, idx: int, queue_num: int, *, forwarded: bool = True + ) -> AsyncIterator[None]: + if forwarded: + chain = "FORWARD" + else: + chain = "OUTPUT" + + rule = f"{chain} -d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}" + await trio.run_process(f"/sbin/iptables -A {rule}".split()) try: - rule = f"-d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}" - await trio.run_process(f"/sbin/iptables -A FORWARD {rule}".split()) - try: + yield + finally: + await trio.run_process(f"/sbin/iptables -D {rule}".split()) + + @asynccontextmanager + async def capture_packets_to( + self, + idx: int, + cb: Callable[ + ["trio.MemorySendChannel[netfilterqueue.Packet]", netfilterqueue.Packet], + None, + ] = _default_capture_cb, + **options: int, + ) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]: + + packets_w, packets_r = trio.open_memory_channel(math.inf) + queue_num, nfq = self.bind_queue(partial(cb, packets_w), **options) + try: + async with self.enqueue_packets_to(idx, queue_num): async with packets_w, trio.open_nursery() as nursery: @nursery.start_soon @@ -231,8 +272,6 @@ class Harness: yield packets_r nursery.cancel_scope.cancel() - finally: - await trio.run_process(f"/sbin/iptables -D FORWARD {rule}".split()) finally: nfq.unbind() diff --git a/tests/test_basic.py b/tests/test_basic.py index 264aa8d..fd1842a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,6 +1,15 @@ +import gc import struct import trio +import trio.testing import pytest +import signal +import socket +import sys +import time +import weakref + +from netfilterqueue import 
NetfilterQueue, COPY_META async def test_comms_without_queue(harness): @@ -61,6 +70,7 @@ async def test_rewrite_reorder(harness): payload = packet.get_payload()[28:] if payload == b"one": set_udp_payload(packet, b"numero uno") + assert b"numero uno" == packet.get_payload()[28:] packet.accept() elif payload == b"two": two = packet @@ -78,11 +88,83 @@ async def test_rewrite_reorder(harness): await harness.expect(2, b"numero uno", b"three", b"TWO", b"four") +async def test_mark_repeat(harness): + counter = 0 + timestamps = [] + + def cb(chan, pkt): + nonlocal counter + with pytest.raises(RuntimeError, match="Packet has no payload"): + pkt.get_payload() + assert pkt.get_mark() == counter + timestamps.append(pkt.get_timestamp()) + if counter < 5: + counter += 1 + pkt.set_mark(counter) + pkt.repeat() + assert pkt.get_mark() == counter + else: + pkt.accept() + + async with harness.capture_packets_to(2, cb, mode=COPY_META): + t0 = time.time() + await harness.send(2, b"testing") + await harness.expect(2, b"testing") + t1 = time.time() + assert counter == 5 + # All iterations of the packet have the same timestamps + assert all(t == timestamps[0] for t in timestamps[1:]) + assert t0 < timestamps[0] < t1 + + +async def test_hwaddr(harness): + hwaddrs = [] + + def cb(pkt): + hwaddrs.append((pkt.get_hw(), pkt.hook, pkt.get_payload()[28:])) + pkt.accept() + + queue_num, nfq = harness.bind_queue(cb) + try: + async with trio.open_nursery() as nursery: + + @nursery.start_soon + async def listen_for_packets(): + while True: + await trio.lowlevel.wait_readable(nfq.get_fd()) + nfq.run(block=False) + + async with harness.enqueue_packets_to(2, queue_num, forwarded=True): + await harness.send(2, b"one", b"two") + await harness.expect(2, b"one", b"two") + async with harness.enqueue_packets_to(2, queue_num, forwarded=False): + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + for payload in (b"three", b"four"): + sock.sendto(payload, harness.dest_addr[2]) + with trio.fail_after(1): + while len(hwaddrs) < 4: + await trio.sleep(0.1) + nursery.cancel_scope.cancel() + finally: + nfq.unbind() + + # Forwarded packets capture a hwaddr, but OUTPUT don't + FORWARD = 2 + OUTPUT = 3 + mac1 = hwaddrs[0][0] + assert mac1 is not None + assert hwaddrs == [ + (mac1, FORWARD, b"one"), + (mac1, FORWARD, b"two"), + (None, OUTPUT, b"three"), + (None, OUTPUT, b"four"), + ] + + async def test_errors(harness): with pytest.warns(RuntimeWarning, match="rcvbuf limit is") as record: async with harness.capture_packets_to(2, sock_len=2 ** 30): pass - assert record[0].filename.endswith("conftest.py") async with harness.capture_packets_to(2, queue_num=0): @@ -90,9 +172,94 @@ async def test_errors(harness): async with harness.capture_packets_to(2, queue_num=0): pass - from netfilterqueue import NetfilterQueue + _, nfq = harness.bind_queue(lambda: None, queue_num=1) + with pytest.raises(RuntimeError, match="A queue is already bound"): + nfq.bind(2, lambda p: None) + # Test unbinding via __del__ + nfq = weakref.ref(nfq) + for _ in range(4): + gc.collect() + if nfq() is None: + break + else: + raise RuntimeError("Couldn't trigger garbage collection of NFQ") + + +async def test_unretained(harness): + def cb(chan, pkt): + # Can access payload within callback + assert pkt.get_payload()[-3:] in (b"one", b"two") + chan.send_nowait(pkt) + + # Capture packets without retaining -> can't access payload after cb returns + async with harness.capture_packets_to(2, cb) as chan: + await harness.send(2, b"one", b"two") + accept = True + async for p 
in chan: + with pytest.raises( + RuntimeError, match="Payload data is no longer available" + ): + p.get_payload() + # Can still issue verdicts though + if accept: + p.accept() + accept = False + else: + break + + with pytest.raises(RuntimeError, match="Parent queue has already been unbound"): + p.drop() + await harness.expect(2, b"one") + + +async def test_cb_exception(harness): + pkt = None + + def cb(channel, p): + nonlocal pkt + pkt = p + raise ValueError("test") + + # Error raised within run(): + with pytest.raises(ValueError, match="test"): + async with harness.capture_packets_to(2, cb): + await harness.send(2, b"boom") + with trio.fail_after(1): + try: + await trio.sleep_forever() + finally: + # At this point the error has been raised (since we were + # cancelled) but the queue is still open. We shouldn't + # be able to access the payload, since we didn't retain(), + # but verdicts should otherwise work. + with pytest.raises(RuntimeError, match="Payload data is no longer"): + pkt.get_payload() + pkt.accept() + + await harness.expect(2, b"boom") + + with pytest.raises(RuntimeError, match="Verdict already given for this packet"): + pkt.drop() + + +@pytest.mark.skipif( + sys.implementation.name == "pypy", + reason="pypy does not support PyErr_CheckSignals", +) +def test_signal(): nfq = NetfilterQueue() nfq.bind(1, lambda p: None, sock_len=131072) - with pytest.raises(RuntimeError, match="A queue is already bound"): - nfq.bind(2, lambda p: None, sock_len=131072) + + def raise_alarm(sig, frame): + raise KeyboardInterrupt("brrrrrring!") + + old_handler = signal.signal(signal.SIGALRM, raise_alarm) + old_timer = signal.setitimer(signal.ITIMER_REAL, 0.5, 0) + try: + with pytest.raises(KeyboardInterrupt, match="brrrrrring!") as exc_info: + nfq.run() + assert any("NetfilterQueue.run" in line.name for line in exc_info.traceback) + finally: + signal.setitimer(signal.ITIMER_REAL, *old_timer) + signal.signal(signal.SIGALRM, old_handler)
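
For reference, here is a minimal usage sketch of the API documented in the README changes above. It is not part of the patch: the queue number is illustrative, it assumes root privileges, and it assumes an existing iptables rule such as ``-j NFQUEUE --queue-num 1`` directing traffic to that queue::

    from netfilterqueue import NetfilterQueue

    def inspect_and_accept(packet):
        # get_payload() returns bytes beginning with the IP header; call
        # packet.retain() first if you need the payload after the callback
        # returns.
        print(packet, "mark:", packet.get_mark(), "hook:", packet.hook)
        print("first 28 bytes:", packet.get_payload()[:28].hex())
        packet.accept()  # or packet.drop() / packet.repeat()

    nfq = NetfilterQueue()
    nfq.bind(1, inspect_and_accept)  # must match --queue-num in the iptables rule
    try:
        # Blocks until an exception (e.g. Ctrl+C) is raised. For async event
        # loops, poll nfq.get_fd() for readability and call
        # nfq.run(block=False) each time data appears instead.
        nfq.run()
    finally:
        nfq.unbind()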