# python-netfilterqueue/tests/test_basic.py
# (221 lines, 7.7 KiB, Python)

import struct
import trio
import trio.testing
import pytest
import signal
import sys
from netfilterqueue import NetfilterQueue
async def test_comms_without_queue(harness):
    """With no queue bound, datagrams sent through the harness arrive intact."""
    exchanges = (
        (2, (b"hello", b"world")),
        (1, (b"it works?",)),
    )
    for dest, msgs in exchanges:
        await harness.send(dest, *msgs)
        await harness.expect(dest, *msgs)
async def test_queue_dropping(harness):
    """Dropped packets never arrive; once capture ends, everything flows."""

    async def drop_matching(stream, victim):
        # Drop the packet whose data after the first 28 bytes (IPv4 + UDP
        # headers) equals `victim`; accept every other packet.
        async for pkt in stream:
            verdict = pkt.drop if pkt.get_payload()[28:] == victim else pkt.accept
            verdict()

    everything = (b"one", b"two", b"three")

    async with trio.open_nursery() as nursery:
        async with harness.capture_packets_to(2) as to_two:
            async with harness.capture_packets_to(1) as to_one:
                nursery.start_soon(drop_matching, to_two, b"two")
                nursery.start_soon(drop_matching, to_one, b"one")
                for dest in (2, 1):
                    await harness.send(dest, *everything)
                await harness.expect(2, b"one", b"three")
                await harness.expect(1, b"two", b"three")

    # Once we stop capturing, everything gets through again:
    for dest in (2, 1):
        await harness.send(dest, *everything)
    for dest in (2, 1):
        await harness.expect(dest, *everything)
async def test_rewrite_reorder(harness):
    """Payloads can be rewritten and verdicts can be issued out of order.

    "one" is rewritten to a longer payload, and "two" is held back until
    "three" arrives. It is then shortened and released after "three".
    """

    def fold_checksum(header):
        # One's-complement sum of the first ten 16-bit words of `header`,
        # with carries folded back into the low 16 bits, then inverted.
        acc = sum(struct.unpack(">10H", header[:20]))
        while acc >> 16:
            acc = (acc & 0xFFFF) + (acc >> 16)
        return acc ^ 0xFFFF

    def set_udp_payload(pkt, new_body):
        # Replace everything after byte 28 with `new_body`, patching the
        # length fields and header checksum when the size changes.
        buf = bytearray(pkt.get_payload())
        if len(new_body) != len(buf) - 28:
            buf[2:4] = struct.pack(">H", len(new_body) + 28)
            buf[24:26] = struct.pack(">H", len(new_body) + 8)
            # Recompute checksum too (over a zeroed checksum field)
            buf[10:12] = b"\x00\x00"
            buf[10:12] = struct.pack(">H", fold_checksum(buf))
        # Clear UDP checksum and set payload
        buf[28:] = new_body
        buf[26:28] = b"\x00\x00"
        pkt.set_payload(bytes(buf))

    async def munge(packets):
        async for pkt in packets:
            body = pkt.get_payload()[28:]
            if body == b"two":
                # No verdict yet: released once "three" shows up.
                held = pkt
            elif body == b"three":
                set_udp_payload(held, b"TWO")
                pkt.accept()
                held.accept()
            else:
                if body == b"one":
                    set_udp_payload(pkt, b"numero uno")
                    assert b"numero uno" == pkt.get_payload()[28:]
                pkt.accept()

    async with trio.open_nursery() as nursery:
        async with harness.capture_packets_to(2) as to_two:
            nursery.start_soon(munge, to_two)
            await harness.send(2, b"one", b"two", b"three", b"four")
            await harness.expect(2, b"numero uno", b"three", b"TWO", b"four")
async def test_errors(harness):
    """Exercise the warning and error paths of queue setup.

    Covers an oversized ``sock_len`` (RuntimeWarning), a queue number that is
    already in use (OSError), and binding a second queue on one
    NetfilterQueue object (RuntimeError).
    """
    with pytest.warns(RuntimeWarning, match="rcvbuf limit is") as record:
        async with harness.capture_packets_to(2, sock_len=2 ** 30):
            pass
    # The warning's reported location should be conftest.py (the harness
    # call site) rather than somewhere inside the library.
    assert record[0].filename.endswith("conftest.py")

    async with harness.capture_packets_to(2, queue_num=0):
        # Queue 0 is held by the outer capture, so a second one must fail.
        with pytest.raises(OSError, match="Failed to create queue"):
            async with harness.capture_packets_to(2, queue_num=0):
                pass

    nfq = NetfilterQueue()
    nfq.bind(1, lambda p: None, sock_len=131072)
    try:
        # A NetfilterQueue object can only be bound to one queue at a time.
        with pytest.raises(RuntimeError, match="A queue is already bound"):
            nfq.bind(2, lambda p: None, sock_len=131072)
    finally:
        # Release queue 1 deterministically instead of relying on garbage
        # collection; other tests (e.g. test_signal) bind it again.
        nfq.unbind()
async def test_unretained(harness):
    """Packets delivered without retaining lose their payload, not verdicts.

    Passing ``trio.MemorySendChannel.send_nowait`` as the callback presumably
    skips the retain step; confirm against the conftest harness.
    """
    async with harness.capture_packets_to(
        2, trio.MemorySendChannel.send_nowait
    ) as chan:
        await harness.send(2, b"one", b"two")
        verdicts_issued = 0
        async for p in chan:
            with pytest.raises(RuntimeError, match="Payload data is no longer available"):
                p.get_payload()
            if verdicts_issued:
                # Second packet: leave it without a verdict.
                break
            # Verdicts still work even though the payload is gone.
            p.accept()
            verdicts_issued += 1

    # After the capture has closed, the queue is unbound and verdicts fail.
    with pytest.raises(RuntimeError, match="Parent queue has already been unbound"):
        p.drop()
    await harness.expect(2, b"one")
async def test_cb_exception(harness):
    """An exception raised by the packet callback propagates out of the capture.

    The packet handed to the failing callback can still receive a verdict
    while the queue is open, but only one verdict.
    """
    pkt = None

    def cb(channel, p):
        # Stash the packet for the assertions below, then fail.
        nonlocal pkt
        pkt = p
        raise ValueError("test")

    # Error raised within run():
    with pytest.raises(ValueError, match="test"):
        async with harness.capture_packets_to(2, cb):
            await harness.send(2, b"boom")
            # Presumably the harness cancels this sleep when cb raises;
            # fail_after(1) bounds the wait if that never happens.
            with trio.fail_after(1):
                try:
                    await trio.sleep_forever()
                finally:
                    # At this point the error has been raised (since we were
                    # cancelled) but the queue is still open. We shouldn't
                    # be able to access the payload, since we didn't retain(),
                    # but verdicts should otherwise work.
                    with pytest.raises(RuntimeError, match="Payload data is no longer"):
                        pkt.get_payload()
                    pkt.accept()

    # The accepted packet is delivered.
    await harness.expect(2, b"boom")

    # A second verdict on the same packet is rejected.
    with pytest.raises(RuntimeError, match="Verdict already given for this packet"):
        pkt.drop()
async def test_cb_exception_during_unbind(harness, capsys):
    """A callback that raises while the queue is being unbound is reported
    as an unraisable exception instead of propagating."""
    pkt = None

    def cb(channel, p):
        # Stash the packet for the assertions below, then fail.
        nonlocal pkt
        pkt = p
        raise ValueError("test")

    # sys.unraisablehook (and hence pytest's helper) needs Python 3.8+. On
    # older versions a no-op stand-in yields None, which selects the stderr
    # check further down.
    if sys.version_info >= (3, 8):
        from _pytest.unraisableexception import catch_unraisable_exception
    else:
        from contextlib import contextmanager

        @contextmanager
        def catch_unraisable_exception():
            yield

    with catch_unraisable_exception() as unraise, trio.CancelScope() as cscope:
        async with harness.capture_packets_to(2, cb):
            # Cancel the task that reads from netfilter:
            cscope.cancel()
            # Shielded so these awaits run even though cscope is cancelled.
            with trio.CancelScope(shield=True):
                await trio.testing.wait_all_tasks_blocked()
                # Now actually send the packet and wait for the report to appear
                # (hopefully)
                await harness.send(2, b"boom boom")
                await trio.sleep(0.5)
        # Exiting the block calls unbind() and raises the exception in the cb.
        # It gets caught and discarded as unraisable.

        # NOTE(review): this check stays inside the `with` because pytest's
        # helper appears to discard `.unraisable` when it exits. Confirm
        # against the pinned pytest version.
        if unraise:
            assert unraise.unraisable
            assert unraise.unraisable.object == "netfilterqueue callback during unbind"
            assert unraise.unraisable.exc_type is ValueError
            assert str(unraise.unraisable.exc_value) == "test"

    # Fallback path (no hook available): the report should go to stderr.
    if not unraise:
        assert (
            "Exception ignored in: 'netfilterqueue callback" in capsys.readouterr().err
        )

    with pytest.raises(RuntimeError, match="Payload data is no longer available"):
        pkt.get_payload()
    with pytest.raises(RuntimeError, match="Parent queue has already been unbound"):
        pkt.drop()
def test_signal():
    """A signal handler's exception interrupts a blocking ``nfq.run()``.

    SIGALRM fires after 0.5 s and raises KeyboardInterrupt. The exception must
    escape ``run()``, with a traceback that passes through
    ``NetfilterQueue.run``.
    """
    nfq = NetfilterQueue()
    nfq.bind(1, lambda p: None, sock_len=131072)

    def raise_alarm(sig, frame):
        raise KeyboardInterrupt("brrrrrring!")

    old_handler = signal.signal(signal.SIGALRM, raise_alarm)
    old_timer = signal.setitimer(signal.ITIMER_REAL, 0.5, 0)
    try:
        with pytest.raises(KeyboardInterrupt, match="brrrrrring!") as exc_info:
            nfq.run()
        assert any("NetfilterQueue.run" in line.name for line in exc_info.traceback)
    finally:
        # Restore the timer before the handler, so a still-pending alarm
        # cannot reach the previous handler.
        signal.setitimer(signal.ITIMER_REAL, *old_timer)
        signal.signal(signal.SIGALRM, old_handler)
        # exc_info's traceback references this frame, whose locals reference
        # exc_info. That cycle can keep nfq (and queue 1) alive until the
        # cyclic GC runs, so unbind explicitly.
        nfq.unbind()