python-netfilterqueue/tests/test_basic.py

264 lines
8.7 KiB
Python

import gc
import struct
import trio
import trio.testing
import pytest
import signal
import socket
import sys
import time
import weakref
from netfilterqueue import NetfilterQueue
async def test_comms_without_queue(harness):
    """With no queue attached, datagrams reach each endpoint unchanged."""
    rounds = [
        (2, [b"hello", b"world"]),
        (1, [b"it works?"]),
    ]
    for endpoint, messages in rounds:
        await harness.send(endpoint, *messages)
        await harness.expect(endpoint, *messages)
async def test_queue_dropping(harness):
    """Packets dropped via the queue never arrive; the rest are delivered.

    Each endpoint gets its own capture. A task drops one specific payload
    on each capture and accepts everything else. After the captures are
    torn down, every payload reaches both endpoints again.
    """
    everything = (b"one", b"two", b"three")

    async def drop(packets, msg):
        # Bytes 28+ are treated as the UDP payload (presumably a 20-byte
        # IPv4 header plus an 8-byte UDP header).
        async for packet in packets:
            verdict = packet.drop if packet.get_payload()[28:] == msg else packet.accept
            verdict()

    async with trio.open_nursery() as nursery:
        async with harness.capture_packets_to(2) as p2:
            async with harness.capture_packets_to(1) as p1:
                nursery.start_soon(drop, p2, b"two")
                nursery.start_soon(drop, p1, b"one")

                for endpoint in (2, 1):
                    await harness.send(endpoint, *everything)
                await harness.expect(2, b"one", b"three")
                await harness.expect(1, b"two", b"three")

        # Once we stop capturing, everything gets through again:
        for endpoint in (2, 1):
            await harness.send(endpoint, *everything)
        for endpoint in (2, 1):
            await harness.expect(endpoint, *everything)
async def test_rewrite_reorder(harness):
    """Queued packets can be rewritten in place and their verdicts reordered.

    "one" is rewritten to "numero uno" and accepted immediately. "two" is
    held without a verdict until "three" arrives; it is then rewritten to
    "TWO" and accepted after "three", so the receiver sees the two swapped.
    """

    async def munge(packets):
        def set_udp_payload(p, msg):
            # Replace the bytes after offset 28 with msg, fixing up the
            # length and checksum fields the offsets below refer to.
            buf = bytearray(p.get_payload())
            if len(msg) != len(buf) - 28:
                # Size changes: patch the IP total length (bytes 2-3) and the
                # UDP length (bytes 24-25), then recompute the IP header
                # checksum (bytes 10-11) over the first 20 bytes as a
                # ones'-complement sum of 16-bit words.
                buf[2:4] = struct.pack(">H", len(msg) + 28)
                buf[24:26] = struct.pack(">H", len(msg) + 8)
                buf[10:12] = b"\x00\x00"
                total = sum(struct.unpack(">10H", buf[:20]))
                while total >> 16:
                    total = (total & 0xFFFF) + (total >> 16)
                buf[10:12] = struct.pack(">H", total ^ 0xFFFF)
            # Swap in the new payload and zero the UDP checksum field
            buf[28:] = msg
            buf[26:28] = b"\x00\x00"
            p.set_payload(bytes(buf))

        async for packet in packets:
            body = packet.get_payload()[28:]
            if body == b"two":
                # Hold "two" (no verdict yet) until "three" shows up.
                two = packet
            elif body == b"three":
                set_udp_payload(two, b"TWO")
                packet.accept()
                two.accept()
            elif body == b"one":
                set_udp_payload(packet, b"numero uno")
                assert b"numero uno" == packet.get_payload()[28:]
                packet.accept()
            else:
                packet.accept()

    async with trio.open_nursery() as nursery:
        async with harness.capture_packets_to(2) as p2:
            nursery.start_soon(munge, p2)
            await harness.send(2, b"one", b"two", b"three", b"four")
            await harness.expect(2, b"numero uno", b"three", b"TWO", b"four")
async def test_mark_repeat(harness):
    """repeat() re-delivers a packet carrying the mark set before repeating.

    The callback expects the mark to equal the number of passes so far. It
    bumps the mark and repeats the packet five times, then accepts it on the
    sixth pass. Every pass must report the same timestamp.
    """
    counter = 0
    timestamps = []

    def cb(chan, pkt):
        nonlocal counter
        assert pkt.get_mark() == counter
        timestamps.append(pkt.get_timestamp())
        if counter >= 5:
            pkt.accept()
            return
        counter += 1
        pkt.set_mark(counter)
        pkt.repeat()
        # The freshly set mark is still readable after repeat()
        assert pkt.get_mark() == counter

    async with harness.capture_packets_to(2, cb):
        t0 = time.time()
        await harness.send(2, b"testing")
        await harness.expect(2, b"testing")
        t1 = time.time()
    assert counter == 5
    first, *rest = timestamps
    # All iterations of the packet have the same timestamps
    assert all(t == first for t in rest)
    assert t0 < first < t1
async def test_hwaddr(harness):
    """get_hw() returns an address for forwarded packets, None for OUTPUT ones.

    Packets enqueued with forwarded=True must report a non-None hardware
    address and the FORWARD hook. Packets sent from a local socket
    (forwarded=False) must report None and the OUTPUT hook.
    """
    hwaddrs = []

    def cb(pkt):
        hwaddrs.append((pkt.get_hw(), pkt.hook, pkt.get_payload()[28:]))
        pkt.accept()

    queue_num, nfq = harness.bind_queue(cb)
    try:
        async with trio.open_nursery() as nursery:

            @nursery.start_soon
            async def listen_for_packets():
                # Drive the queue from trio: wait until the queue's fd is
                # readable, then process pending packets without blocking.
                while True:
                    await trio.lowlevel.wait_readable(nfq.get_fd())
                    nfq.run(block=False)

            async with harness.enqueue_packets_to(2, queue_num, forwarded=True):
                await harness.send(2, b"one", b"two")
                await harness.expect(2, b"one", b"two")

            async with harness.enqueue_packets_to(2, queue_num, forwarded=False):
                # Send from a locally created socket. Using `with` closes it
                # deterministically; the original leaked the fd. It stays
                # open until the callback has seen all four packets.
                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                    for payload in (b"three", b"four"):
                        sock.sendto(payload, harness.dest_addr[2])
                    with trio.fail_after(1):
                        while len(hwaddrs) < 4:
                            await trio.sleep(0.1)

            nursery.cancel_scope.cancel()
    finally:
        nfq.unbind()

    # Forwarded packets carry a hardware address; OUTPUT packets don't.
    FORWARD = 2
    OUTPUT = 3
    mac1 = hwaddrs[0][0]
    assert mac1 is not None
    assert hwaddrs == [
        (mac1, FORWARD, b"one"),
        (mac1, FORWARD, b"two"),
        (None, OUTPUT, b"three"),
        (None, OUTPUT, b"four"),
    ]
async def test_errors(harness):
    """Exercise the warning and error paths of queue setup and teardown."""
    # An oversized socket buffer request warns about the rcvbuf limit. The
    # warning must be reported against conftest.py (the caller).
    with pytest.warns(RuntimeWarning, match="rcvbuf limit is") as caught:
        async with harness.capture_packets_to(2, sock_len=2 ** 30):
            pass
    assert caught[0].filename.endswith("conftest.py")

    # Capturing to a queue number that is already in use fails.
    async with harness.capture_packets_to(2, queue_num=0):
        with pytest.raises(OSError, match="Failed to create queue"):
            async with harness.capture_packets_to(2, queue_num=0):
                pass

    # Binding an already-bound NetfilterQueue object a second time is rejected.
    _, nfq = harness.bind_queue(lambda: None, queue_num=1)
    with pytest.raises(RuntimeError, match="A queue is already bound"):
        nfq.bind(2, lambda p: None)

    # Test unbinding via __del__: drop the only strong reference and give
    # the collector a few passes to reclaim the object.
    nfq_ref = weakref.ref(nfq)
    del nfq
    attempts = 4
    for _ in range(attempts):
        gc.collect()
        if nfq_ref() is None:
            break
    else:
        raise RuntimeError("Couldn't trigger garbage collection of NFQ")
async def test_unretained(harness):
    """Unretained payloads are readable only inside the callback.

    Packets smuggled out of the callback without retaining their payload can
    still receive verdicts, until their parent queue is unbound.
    """

    def cb(chan, pkt):
        # The payload is still readable while the callback is running
        assert pkt.get_payload()[-3:] in (b"one", b"two")
        chan.send_nowait(pkt)

    # Capture packets without retaining -> can't access payload after cb returns
    async with harness.capture_packets_to(2, cb) as chan:
        await harness.send(2, b"one", b"two")
        seen = 0
        async for p in chan:
            with pytest.raises(
                RuntimeError, match="Payload data is no longer available"
            ):
                p.get_payload()
            seen += 1
            if seen > 1:
                # Leave the second packet without a verdict
                break
            # Can still issue verdicts though
            p.accept()

    # After the capture exits, the held packet's queue is gone
    with pytest.raises(RuntimeError, match="Parent queue has already been unbound"):
        p.drop()

    await harness.expect(2, b"one")
async def test_cb_exception(harness):
    """An exception raised by the packet callback propagates out of the capture.

    The packet that triggered the exception can still receive a verdict, but
    its payload is not readable because it was never retained. A second
    verdict on the same packet is rejected.
    """
    pkt = None

    def cb(channel, p):
        # Stash the packet so the test body can use it after cb raises
        nonlocal pkt
        pkt = p
        raise ValueError("test")

    # Error raised within run():
    with pytest.raises(ValueError, match="test"):
        async with harness.capture_packets_to(2, cb):
            await harness.send(2, b"boom")
            # Safety net: fail rather than hang if the callback error never
            # cancels this body.
            with trio.fail_after(1):
                try:
                    await trio.sleep_forever()
                finally:
                    # At this point the error has been raised (since we were
                    # cancelled) but the queue is still open. We shouldn't
                    # be able to access the payload, since we didn't retain(),
                    # but verdicts should otherwise work.
                    with pytest.raises(RuntimeError, match="Payload data is no longer"):
                        pkt.get_payload()
                    pkt.accept()

    # The accept() issued in the finally block let the packet through
    await harness.expect(2, b"boom")

    # A packet accepts only one verdict
    with pytest.raises(RuntimeError, match="Verdict already given for this packet"):
        pkt.drop()
@pytest.mark.skipif(
    sys.implementation.name == "pypy",
    reason="pypy does not support PyErr_CheckSignals",
)
def test_signal():
    """A raising signal handler interrupts a blocking NetfilterQueue.run().

    SIGALRM is scheduled for 0.5 s after run() is entered on queue 1. The
    handler's KeyboardInterrupt must escape run(), and run() must appear in
    the traceback.
    """
    nfq = NetfilterQueue()
    nfq.bind(1, lambda p: None, sock_len=131072)

    def raise_alarm(sig, frame):
        raise KeyboardInterrupt("brrrrrring!")

    old_handler = signal.signal(signal.SIGALRM, raise_alarm)
    old_timer = signal.setitimer(signal.ITIMER_REAL, 0.5, 0)
    try:
        with pytest.raises(KeyboardInterrupt, match="brrrrrring!") as exc_info:
            nfq.run()
        assert any("NetfilterQueue.run" in line.name for line in exc_info.traceback)
    finally:
        # Restore the previous interval timer and SIGALRM handler.
        signal.setitimer(signal.ITIMER_REAL, *old_timer)
        signal.signal(signal.SIGALRM, old_handler)
        # Release queue 1 explicitly rather than leaving it bound until the
        # NetfilterQueue object happens to be garbage-collected.
        nfq.unbind()