Fixes for several open issues

Propagate exceptions raised by the user's packet callback -- fixes #31, #50
Warn about exceptions raised by the packet callback during queue unbinding
Raise an error if a packet verdict is set after its parent queue is closed
set_payload() now affects the result of later get_payload() -- fixes #30
Handle signals received when run() is blocked in recv() -- fixes #65
This commit is contained in:
Joshua Oreman 2022-01-13 03:14:43 -07:00
parent ddbc12a6ab
commit 0187c89611
6 changed files with 240 additions and 61 deletions

View File

@ -1,3 +1,10 @@
v1.0.0, unreleased
Propagate exceptions raised by the user's packet callback
Warn about exceptions raised by the packet callback during queue unbinding
Raise an error if a packet verdict is set after its parent queue is closed
set_payload() now affects the result of later get_payload()
Handle signals received when run() is blocked in recv()
v0.9.0, 12 Jan 2021
Improve usability when Packet objects are retained past the callback
Add Packet.retain() to save the packet contents in such cases

View File

@ -8,6 +8,7 @@ cdef extern from "<errno.h>":
# dummy defines from asm-generic/errno.h:
cdef enum:
EINTR = 4
EAGAIN = 11 # Try again
EWOULDBLOCK = EAGAIN
ENOBUFS = 105 # No buffer space available
@ -115,15 +116,17 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h":
u_int16_t num,
nfq_callback *cb,
void *data)
int nfq_destroy_queue(nfq_q_handle *qh)
int nfq_handle_packet(nfq_handle *h, char *buf, int len)
int nfq_set_mode(nfq_q_handle *qh,
u_int8_t mode, unsigned int len)
q_set_queue_maxlen(nfq_q_handle *qh,
u_int32_t queuelen)
# Any function that parses Netlink replies might invoke the user
# callback and thus might need to propagate a Python exception.
# This includes nfq_handle_packet but is not limited to that --
# other functions might send a query, read until they get the reply,
# and find a packet notification before the reply which they then
# must deal with.
int nfq_destroy_queue(nfq_q_handle *qh) except? -1
int nfq_handle_packet(nfq_handle *h, char *buf, int len) except? -1
int nfq_set_mode(nfq_q_handle *qh, u_int8_t mode, unsigned int len) except? -1
int nfq_set_queue_maxlen(nfq_q_handle *qh, u_int32_t queuelen) except? -1
int nfq_set_verdict(nfq_q_handle *qh,
u_int32_t id,
@ -137,7 +140,6 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h":
u_int32_t mark,
u_int32_t datalen,
unsigned char *buf) nogil
int nfq_set_queue_maxlen(nfq_q_handle *qh, u_int32_t queuelen)
int nfq_fd(nfq_handle *h)
nfqnl_msg_packet_hdr *nfq_get_msg_packet_hdr(nfq_data *nfad)
@ -146,7 +148,7 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h":
nfqnl_msg_packet_hw *nfq_get_packet_hw(nfq_data *nfad)
int nfq_get_nfmark (nfq_data *nfad)
nfnl_handle *nfq_nfnlh(nfq_handle *h)
# Dummy defines from linux/socket.h:
cdef enum: # Protocol families, same as address families.
PF_INET = 2
@ -166,8 +168,14 @@ cdef enum:
NF_STOP
NF_MAX_VERDICT = NF_STOP
cdef class NetfilterQueue:
cdef object user_callback # User callback
cdef nfq_handle *h # Handle to NFQueue library
cdef nfq_q_handle *qh # A handle to the queue
cdef bint unbinding
cdef class Packet:
cdef nfq_q_handle *_qh
cdef NetfilterQueue _queue
cdef bint _verdict_is_set # True if verdict has been issued,
# false otherwise
cdef bint _mark_is_set # True if a mark has been given, false otherwise
@ -196,9 +204,9 @@ cdef class Packet:
#cdef readonly u_int32_t outdev
#cdef readonly u_int32_t physoutdev
cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa)
cdef set_nfq_data(self, NetfilterQueue queue, nfq_data *nfa)
cdef drop_refs(self)
cdef void verdict(self, u_int8_t verdict)
cdef int verdict(self, u_int8_t verdict) except -1
cpdef Py_ssize_t get_payload_len(self)
cpdef double get_timestamp(self)
cpdef bytes get_payload(self)
@ -209,11 +217,3 @@ cdef class Packet:
cpdef accept(self)
cpdef drop(self)
cpdef repeat(self)
cdef class NetfilterQueue:
cdef object user_callback # User callback
cdef nfq_handle *h # Handle to NFQueue library
cdef nfq_q_handle *qh # A handle to the queue
cdef u_int16_t af # Address family
cdef packet_copy_size # Amount of packet metadata + data copied to buffer

View File

@ -26,20 +26,38 @@ DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize // 2
cdef extern from *:
"""
#if PY_MAJOR_VERSION < 3
#define PyBytes_FromStringAndSize PyString_FromStringAndSize
#endif
static void do_write_unraisable(PyObject* obj) {
PyObject *ty, *val, *tb;
PyErr_GetExcInfo(&ty, &val, &tb);
PyErr_Restore(ty, val, tb);
PyErr_WriteUnraisable(obj);
}
"""
cdef void do_write_unraisable(msg)
from cpython.exc cimport PyErr_CheckSignals
# A negative return value from this callback will stop processing and
# make nfq_handle_packet return -1, so we use that as the error flag.
cdef int global_callback(nfq_q_handle *qh, nfgenmsg *nfmsg,
nfq_data *nfa, void *data) with gil:
nfq_data *nfa, void *data) except -1 with gil:
"""Create a Packet and pass it to appropriate callback."""
cdef NetfilterQueue nfqueue = <NetfilterQueue>data
cdef object user_callback = <object>nfqueue.user_callback
packet = Packet()
packet.set_nfq_data(qh, nfa)
user_callback(packet)
packet.drop_refs()
packet.set_nfq_data(nfqueue, nfa)
try:
user_callback(packet)
except BaseException as exc:
if nfqueue.unbinding == True:
do_write_unraisable(
"netfilterqueue callback during unbind"
)
else:
raise
finally:
packet.drop_refs()
return 1
cdef class Packet:
@ -54,7 +72,7 @@ cdef class Packet:
protocol = PROTOCOLS.get(hdr.protocol, "Unknown protocol")
return "%s packet, %s bytes" % (protocol, self.payload_len)
cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa):
cdef set_nfq_data(self, NetfilterQueue queue, nfq_data *nfa):
"""
Assign a packet from NFQ to this object. Parse the header and load
local values.
@ -63,7 +81,7 @@ cdef class Packet:
cdef nfqnl_msg_packet_hdr *hdr
hdr = nfq_get_msg_packet_hdr(nfa)
self._qh = qh
self._queue = queue
self.id = ntohl(hdr.packet_id)
self.hw_protocol = ntohs(hdr.hw_protocol)
self.hook = hdr.hook
@ -90,10 +108,12 @@ cdef class Packet:
"""
self.payload = NULL
cdef void verdict(self, u_int8_t verdict):
cdef int verdict(self, u_int8_t verdict) except -1:
"""Call appropriate set_verdict... function on packet."""
if self._verdict_is_set:
raise RuntimeWarning("Verdict already given for this packet.")
raise RuntimeError("Verdict already given for this packet")
if self._queue.qh == NULL:
raise RuntimeError("Parent queue has already been unbound")
cdef u_int32_t modified_payload_len = 0
cdef unsigned char *modified_payload = NULL
@ -102,7 +122,7 @@ cdef class Packet:
modified_payload = self._given_payload
if self._mark_is_set:
nfq_set_verdict2(
self._qh,
self._queue.qh,
self.id,
verdict,
self._given_mark,
@ -110,7 +130,7 @@ cdef class Packet:
modified_payload)
else:
nfq_set_verdict(
self._qh,
self._queue.qh,
self.id,
verdict,
modified_payload_len,
@ -126,7 +146,9 @@ cdef class Packet:
cpdef bytes get_payload(self):
"""Return payload as Python string."""
if self._owned_payload:
if self._given_payload:
return self._given_payload
elif self._owned_payload:
return self._owned_payload
elif self.payload != NULL:
return self.payload[:self.payload_len]
@ -172,22 +194,23 @@ cdef class Packet:
"""Repeat the packet."""
self.verdict(NF_REPEAT)
cdef class NetfilterQueue:
"""Handle a single numbered queue."""
def __cinit__(self, *args, **kwargs):
self.af = kwargs.get("af", PF_INET)
cdef u_int16_t af # Address family
af = kwargs.get("af", PF_INET)
self.unbinding = False
self.h = nfq_open()
if self.h == NULL:
raise OSError("Failed to open NFQueue.")
nfq_unbind_pf(self.h, self.af) # This does NOT kick out previous
# running queues
if nfq_bind_pf(self.h, self.af) < 0:
raise OSError("Failed to bind family %s. Are you root?" % self.af)
nfq_unbind_pf(self.h, af) # This does NOT kick out previous queues
if nfq_bind_pf(self.h, af) < 0:
raise OSError("Failed to bind family %s. Are you root?" % af)
def __dealloc__(self):
if self.qh != NULL:
nfq_destroy_queue(self.qh)
self.unbind()
# Don't call nfq_unbind_pf unless you want to disconnect any other
# processes using this libnetfilter_queue on this protocol family!
nfq_close(self.h)
@ -232,7 +255,11 @@ cdef class NetfilterQueue:
def unbind(self):
"""Destroy the queue."""
if self.qh != NULL:
nfq_destroy_queue(self.qh)
self.unbinding = True
try:
nfq_destroy_queue(self.qh)
finally:
self.unbinding = False
self.qh = NULL
# See warning about nfq_unbind_pf in __dealloc__ above.
@ -251,11 +278,19 @@ cdef class NetfilterQueue:
while True:
with nogil:
rv = recv(fd, buf, sizeof(buf), recv_flags)
if (rv >= 0):
nfq_handle_packet(self.h, buf, rv)
else:
if errno != ENOBUFS:
if rv < 0:
if errno == EAGAIN:
break
if errno == ENOBUFS:
# Kernel is letting us know we dropped a packet
continue
if errno == EINTR:
PyErr_CheckSignals()
continue
raise OSError(errno, "recv failed")
rv = nfq_handle_packet(self.h, buf, rv)
if rv < 0:
raise OSError(errno, "nfq_handle_packet failed")
def run_socket(self, s):
"""Accept packets using socket.recv so that, for example, gevent can monkeypatch it."""

View File

@ -7,6 +7,7 @@ try:
# Use Cython
from Cython.Build import cythonize
setup_requires = []
ext_modules = cythonize(
Extension(
"netfilterqueue", ["netfilterqueue.pyx"], libraries=["netfilter_queue"]
@ -15,7 +16,11 @@ try:
)
except ImportError:
# No Cython
if not os.path.exists(os.path.join(os.path.dirname(__file__), "netfilterqueue.c")):
if "egg_info" in sys.argv:
# We're being run by pip to figure out what we need. Request cython in
# setup_requires below.
setup_requires = ["cython"]
elif not os.path.exists(os.path.join(os.path.dirname(__file__), "netfilterqueue.c")):
sys.stderr.write(
"You must have Cython installed (`pip install cython`) to build this "
"package from source.\nIf you're receiving this error when installing from "
@ -29,6 +34,7 @@ except ImportError:
setup(
ext_modules=ext_modules,
setup_requires=setup_requires,
name="NetfilterQueue",
version=VERSION,
license="MIT",

View File

@ -7,7 +7,8 @@ import sys
import trio
import unshare
import netfilterqueue
from typing import AsyncIterator
from functools import partial
from typing import AsyncIterator, Callable, Optional
from async_generator import asynccontextmanager
from pytest_trio.enable_trio_mode import *
@ -93,7 +94,7 @@ async def peer_main(idx: int, parent_fd: int) -> None:
# Enter the message-forwarding loop
async def proxy_one_way(src, dest):
while True:
while src.fileno() >= 0:
try:
msg = await src.recv(4096)
except trio.ClosedResourceError:
@ -111,6 +112,14 @@ async def peer_main(idx: int, parent_fd: int) -> None:
nursery.start_soon(proxy_one_way, peer, parent)
def _default_capture_cb(
target: "trio.MemorySendChannel[netfilterqueue.Packet]",
packet: netfilterqueue.Packet,
) -> None:
packet.retain()
target.send_nowait(packet)
class Harness:
def __init__(self):
self._received = {}
@ -155,7 +164,9 @@ class Harness:
"peer subprocess exited with code {}".format(retval)
)
finally:
await trio.run_process(f"ip link delete veth{idx}".split())
# On some kernels the veth device is removed when the subprocess exits
# and its netns goes away. check=False to suppress that error.
await trio.run_process(f"ip link delete veth{idx}".split(), check=False)
async def _manage_peer(self, idx: int, *, task_status):
async with trio.open_nursery() as nursery:
@ -192,24 +203,28 @@ class Harness:
@asynccontextmanager
async def capture_packets_to(
self, idx: int, *, queue_num: int = -1, **options
self,
idx: int,
cb: Callable[
["trio.MemorySendChannel[netfilterqueue.Packet]", netfilterqueue.Packet],
None,
] = _default_capture_cb,
*,
queue_num: int = -1,
**options: int,
) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]:
packets_w, packets_r = trio.open_memory_channel(math.inf)
def stash_packet(p):
p.retain()
packets_w.send_nowait(p)
nfq = netfilterqueue.NetfilterQueue()
# Use a smaller socket buffer to avoid a warning in CI
options.setdefault("sock_len", 131072)
if queue_num >= 0:
nfq.bind(queue_num, stash_packet, **options)
nfq.bind(queue_num, partial(cb, packets_w), **options)
else:
for queue_num in range(16):
try:
nfq.bind(queue_num, stash_packet, **options)
nfq.bind(queue_num, partial(cb, packets_w), **options)
break
except Exception as ex:
last_error = ex

View File

@ -1,6 +1,11 @@
import struct
import trio
import trio.testing
import pytest
import signal
import sys
from netfilterqueue import NetfilterQueue
async def test_comms_without_queue(harness):
@ -61,6 +66,7 @@ async def test_rewrite_reorder(harness):
payload = packet.get_payload()[28:]
if payload == b"one":
set_udp_payload(packet, b"numero uno")
assert b"numero uno" == packet.get_payload()[28:]
packet.accept()
elif payload == b"two":
two = packet
@ -82,7 +88,6 @@ async def test_errors(harness):
with pytest.warns(RuntimeWarning, match="rcvbuf limit is") as record:
async with harness.capture_packets_to(2, sock_len=2 ** 30):
pass
assert record[0].filename.endswith("conftest.py")
async with harness.capture_packets_to(2, queue_num=0):
@ -90,9 +95,120 @@ async def test_errors(harness):
async with harness.capture_packets_to(2, queue_num=0):
pass
from netfilterqueue import NetfilterQueue
nfq = NetfilterQueue()
nfq.bind(1, lambda p: None, sock_len=131072)
with pytest.raises(RuntimeError, match="A queue is already bound"):
nfq.bind(2, lambda p: None, sock_len=131072)
async def test_unretained(harness):
# Capture packets without retaining -> can't access payload
async with harness.capture_packets_to(2, trio.MemorySendChannel.send_nowait) as chan:
await harness.send(2, b"one", b"two")
accept = True
async for p in chan:
with pytest.raises(RuntimeError, match="Payload data is no longer available"):
p.get_payload()
# Can still issue verdicts though
if accept:
p.accept()
accept = False
else:
break
with pytest.raises(RuntimeError, match="Parent queue has already been unbound"):
p.drop()
await harness.expect(2, b"one")
async def test_cb_exception(harness):
pkt = None
def cb(channel, p):
nonlocal pkt
pkt = p
raise ValueError("test")
# Error raised within run():
with pytest.raises(ValueError, match="test"):
async with harness.capture_packets_to(2, cb):
await harness.send(2, b"boom")
with trio.fail_after(1):
try:
await trio.sleep_forever()
finally:
# At this point the error has been raised (since we were
# cancelled) but the queue is still open. We shouldn't
# be able to access the payload, since we didn't retain(),
# but verdicts should otherwise work.
with pytest.raises(RuntimeError, match="Payload data is no longer"):
pkt.get_payload()
pkt.accept()
await harness.expect(2, b"boom")
with pytest.raises(RuntimeError, match="Verdict already given for this packet"):
pkt.drop()
async def test_cb_exception_during_unbind(harness, capsys):
pkt = None
def cb(channel, p):
nonlocal pkt
pkt = p
raise ValueError("test")
if sys.version_info >= (3, 8):
from _pytest.unraisableexception import catch_unraisable_exception
else:
from contextlib import contextmanager
@contextmanager
def catch_unraisable_exception():
pass
with catch_unraisable_exception() as unraise, trio.CancelScope() as cscope:
async with harness.capture_packets_to(2, cb):
# Cancel the task that reads from netfilter:
cscope.cancel()
with trio.CancelScope(shield=True):
await trio.testing.wait_all_tasks_blocked()
# Now actually send the packet and wait for the report to appear
# (hopefully)
await harness.send(2, b"boom boom")
await trio.sleep(0.5)
# Exiting the block calls unbind() and raises the exception in the cb.
# It gets caught and discarded as unraisable.
if unraise:
assert unraise.unraisable
assert unraise.unraisable.object == "netfilterqueue callback during unbind"
assert unraise.unraisable.exc_type is ValueError
assert str(unraise.unraisable.exc_value) == "test"
if not unraise:
assert "Exception ignored in: 'netfilterqueue callback" in capsys.readouterr().err
with pytest.raises(RuntimeError, match="Payload data is no longer available"):
pkt.get_payload()
with pytest.raises(RuntimeError, match="Parent queue has already been unbound"):
pkt.drop()
def test_signal():
nfq = NetfilterQueue()
nfq.bind(1, lambda p: None, sock_len=131072)
def raise_alarm(sig, frame):
raise KeyboardInterrupt("brrrrrring!")
old_handler = signal.signal(signal.SIGALRM, raise_alarm)
old_timer = signal.setitimer(signal.ITIMER_REAL, 0.5, 0)
try:
with pytest.raises(KeyboardInterrupt, match="brrrrrring!") as exc_info:
nfq.run()
assert any("NetfilterQueue.run" in line.name for line in exc_info.traceback)
finally:
signal.setitimer(signal.ITIMER_REAL, *old_timer)
signal.signal(signal.SIGALRM, old_handler)