Merge pull request #74 from oremanj/updates

Modernize, add tests, allow Packet to outlive the callback it's passed to
Joshua Oreman 2022-01-11 22:41:45 -07:00 committed by GitHub
commit afcee0d9bf
12 changed files with 5131 additions and 1975 deletions

.github/workflows/ci.yml (new file)

@ -0,0 +1,34 @@
name: CI
on:
push:
branches:
- master
pull_request:
jobs:
Ubuntu:
name: 'Ubuntu (${{ matrix.python }})'
timeout-minutes: 10
runs-on: 'ubuntu-latest'
strategy:
fail-fast: false
matrix:
python:
- '3.6'
- '3.7'
- '3.8'
- '3.9'
- '3.10'
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Setup python
uses: actions/setup-python@v2
with:
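        # For CPython entries this expression evaluates to a version range like
        # '3.9.0-alpha - 3.9.X' (so prereleases are allowed); a 'pypy-*' entry
        # would be passed through unchanged.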
python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }}
- name: Run tests
run: ./ci.sh
env:
# Should match 'name:' up above
JOB_NAME: 'Ubuntu (${{ matrix.python }})'

CHANGES.txt

@ -1,4 +1,13 @@
v0.8.1 30 Jan 2017
v0.9.0, unreleased
Improve usability when Packet objects are retained past the callback
Add Packet.retain() to save the packet contents in such cases
Eliminate warnings during build on py3
Add CI and basic test suite
Raise a warning, not an error, if we don't get the bufsize we want
Don't allow bind() more than once on the same NetfilterQueue, since
that would leak the old queue handle
v0.8.1, 30 Jan 2017
Fix bug #25- crashing when used in OUTPUT or POSTROUTING chains
v0.8, 15 Dec 2016

README.rst

@ -3,9 +3,10 @@ NetfilterQueue
==============
NetfilterQueue provides access to packets matched by an iptables rule in
Linux. Packets so matched can be accepted, dropped, altered, or given a mark.
Linux. Packets so matched can be accepted, dropped, altered, reordered,
or given a mark.
Libnetfilter_queue (the netfilter library, not this module) is part of the
libnetfilter_queue (the netfilter library, not this module) is part of the
`Netfilter project <http://netfilter.org/projects/libnetfilter_queue/>`_.
Example
@ -15,18 +16,18 @@ The following script prints a short description of each packet before accepting
it. ::
from netfilterqueue import NetfilterQueue
def print_and_accept(pkt):
print(pkt)
pkt.accept()
nfqueue = NetfilterQueue()
nfqueue.bind(1, print_and_accept)
try:
nfqueue.run()
except KeyboardInterrupt:
print('')
nfqueue.unbind()
You can also make your own socket so that it can be used with gevent, for example. ::
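A rough sketch of that pattern, using the ``get_fd`` and ``run_socket`` methods described below (the queue number and callback are placeholders)::

    import socket
    from netfilterqueue import NetfilterQueue

    def print_and_accept(pkt):
        print(pkt)
        pkt.accept()

    nfqueue = NetfilterQueue()
    nfqueue.bind(1, print_and_accept)
    # Wrap the queue's file descriptor in an ordinary Python socket so a
    # library such as gevent can monkeypatch the recv call.
    s = socket.fromfd(nfqueue.get_fd(), socket.AF_NETLINK, socket.SOCK_RAW)
    try:
        nfqueue.run_socket(s)
    except KeyboardInterrupt:
        print('')
    finally:
        s.close()
        nfqueue.unbind()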
@ -56,7 +57,7 @@ To send packets destined for your LAN to the script, type something like::
Installation
============
NetfilterQueue is a C extension module that links against libnetfilter_queue.
Before installing, ensure you have:
1. A C compiler
@ -81,9 +82,9 @@ From source
To install from source::
git clone git@github.com:kti/python-netfilterqueue.git
git clone https://github.com/oremanj/python-netfilterqueue
cd python-netfilterqueue
python setup.py install
pip install .
If Cython is installed, Distutils will use it to regenerate the .c source from the .pyx. It will then compile the .c into a .so.
@ -104,9 +105,12 @@ NetfilterQueue objects
A NetfilterQueue object represents a single queue. Configure your queue with
a call to ``bind``, then start receiving packets with a call to ``run``.
``QueueHandler.bind(queue_num, callback[, max_len[, mode[, range, [sock_len]]]])``
Create and bind to the queue. ``queue_num`` must match the number in your
iptables rule. ``callback`` is a function or method that takes one
``QueueHandler.bind(queue_num, callback[, max_len[, mode[, range[, sock_len]]]])``
Create and bind to the queue. ``queue_num`` uniquely identifies this
queue for the kernel. It must match the ``--queue-num`` in your iptables
rule, but there is no ordering requirement: it's fine to either ``bind()``
first or set up the iptables rule first.
``callback`` is a function or method that takes one
argument, a Packet object (see below). ``max_len`` sets the largest number
of packets that can be in the queue; new packets are dropped if the size of
the queue reaches this number. ``mode`` determines how much of the packet
@ -119,17 +123,21 @@ a call to ``bind``, then start receiving packets with a call to ``run``.
Remove the queue. Packets matched by your iptables rule will be dropped.
``QueueHandler.get_fd()``
Get the file descriptor of the queue handler.
Get the file descriptor of the socket used to receive queued
packets and send verdicts. If you're using an async event loop,
you can poll this FD for readability and call ``run(False)`` every
time data appears on it.
``QueueHandler.run([block])``
Send packets to your callback. By default, this method blocks. Set
block=False to let your thread continue. You can get the file descriptor
of the socket with the ``get_fd`` method.
Send packets to your callback. By default, this method blocks, running
until an exception is raised (such as by Ctrl+C). Set
block=False to process the pending messages without waiting for more.
You can get the file descriptor of the socket with the ``get_fd`` method.
``QueueHandler.run_socket(socket)``
Send packets to your callback, but use the supplied socket instead of
recv, so that, for example, gevent can monkeypatch it. You can make a
socket with ``socket.fromfd(nfqueue.get_fd(), socket.AF_UNIX, socket.SOCK_STREAM)``
socket with ``socket.fromfd(nfqueue.get_fd(), socket.AF_NETLINK, socket.SOCK_RAW)``
and optionally make it non-blocking with ``socket.setblocking(False)``.
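For instance, a minimal sketch of the non-blocking pattern described above, using asyncio (the queue number and callback are placeholders)::

    import asyncio
    from netfilterqueue import NetfilterQueue

    def handle(pkt):
        pkt.accept()

    nfqueue = NetfilterQueue()
    nfqueue.bind(1, handle)
    loop = asyncio.new_event_loop()
    # Drain the pending packets each time the netlink socket becomes readable.
    loop.add_reader(nfqueue.get_fd(), nfqueue.run, False)
    try:
        loop.run_forever()
    finally:
        loop.remove_reader(nfqueue.get_fd())
        nfqueue.unbind()
        loop.close()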
Packet objects
@ -138,42 +146,65 @@ Packet objects
Objects of this type are passed to your callback.
``Packet.get_payload()``
Return the packet's payload as a string (Python 2) or bytes (Python 3).
Return the packet's payload as a bytes object. The returned value
starts with the IP header. You must call ``retain()`` if you want
to be able to ``get_payload()`` after your callback has returned.
``Packet.set_payload(payload)``
Set the packet payload. ``payload`` is a bytes.
Set the packet payload. Call this before ``accept()`` if you want to
change the contents of the packet before allowing it to be released.
Don't forget to update the transport-layer checksum (or clear it,
if you're using UDP), or else the recipient is likely to drop the
packet. If you're changing the length of the packet, you'll also need
to update the IP length, IP header checksum, and probably some
transport-level fields (such as UDP length for UDP); a sketch of this
bookkeeping appears at the end of this section.
``Packet.get_payload_len()``
Return the size of the payload.
``Packet.set_mark(mark)``
Give the packet a kernel mark. ``mark`` is a 32-bit number.
Give the packet a kernel mark, which can be used in future iptables
rules. ``mark`` is a 32-bit number.
``Packet.get_mark()``
Get the mark already on the packet.
Get the mark already on the packet (either the one you set using
``set_mark()``, or the one it arrived with if you haven't called
``set_mark()``).
``Packet.get_hw()``
Return the hardware address as a Python string.
``Packet.retain()``
Allocate a copy of the packet payload for use after the callback
has returned. ``get_payload()`` will raise an exception at that
point if you didn't call ``retain()``.
``Packet.accept()``
Accept the packet.
Accept the packet. You can reorder packets by accepting them
in a different order than the order in which they were passed
to your callback.
``Packet.drop()``
Drop the packet.
``Packet.repeat()``
Iterate the same cycle once more.
Restart processing of this packet from the beginning of its
Netfilter hook (iptables chain, roughly). Any changes made
using ``set_payload()`` or ``set_mark()`` are preserved; in the
absence of such changes, the packet will probably come right
back to the same queue; see the sketch in the Usage section below for
one way to break that loop with ``set_mark()``.
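As an illustration of the bookkeeping that ``set_payload()`` requires, here is a rough sketch that replaces the payload of a UDP-over-IPv4 packet and patches up the lengths and the IP header checksum. It assumes a 20-byte IP header with no options and simply clears the UDP checksum (0 means "not computed") rather than recomputing it::

    import struct

    def set_udp_payload(pkt, msg):
        data = bytearray(pkt.get_payload())
        data[2:4] = struct.pack(">H", 20 + 8 + len(msg))  # IP total length
        data[24:26] = struct.pack(">H", 8 + len(msg))     # UDP length
        data[26:28] = b"\x00\x00"                         # UDP checksum: 0 = none
        data[28:] = msg
        # Recompute the IPv4 header checksum over the 20-byte header.
        data[10:12] = b"\x00\x00"
        cksum = sum(struct.unpack(">10H", data[:20]))
        while cksum >> 16:
            cksum = (cksum & 0xFFFF) + (cksum >> 16)
        data[10:12] = struct.pack(">H", cksum ^ 0xFFFF)
        pkt.set_payload(bytes(data))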
Callback objects
----------------
Your callback can be function or a method and must accept one argument, a
Packet object. You must call either Packet.accept() or Packet.drop() before
returning.
``callback(packet)`` or ``callback(self, packet)``
Handle a single packet from the queue. You must call either
``packet.accept()`` or ``packet.drop()``.
Your callback can be any one-argument callable and will be invoked with
a ``Packet`` object as argument. You must call ``retain()`` within the
callback if you want to be able to ``get_payload()`` after the callback
has returned. You can hang onto ``Packet`` objects and resolve them later,
but note that packets continue to count against the queue size limit
until they've been given a verdict (accept, drop, or repeat). Also, the
kernel stores the enqueued packets in a linked list, so keeping lots of packets
outstanding is likely to adversely impact performance.
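For example, a minimal sketch of retaining packets and issuing verdicts outside the callback (the queue number and the batch size of 10 are arbitrary)::

    from netfilterqueue import NetfilterQueue

    pending = []

    def hold(pkt):
        pkt.retain()          # copy the payload so get_payload() keeps working
        pending.append(pkt)   # no verdict yet; still counts against max_len
        if len(pending) >= 10:
            for p in pending:
                print(p.get_payload()[:20])
                p.accept()    # verdicts may be issued in any order
            pending.clear()

    nfqueue = NetfilterQueue()
    nfqueue.bind(1, hold)
    nfqueue.run()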
Usage
=====
@ -181,12 +212,12 @@ Usage
To send packets to the queue::
iptables -I <table or chain> <match specification> -j NFQUEUE --queue-num <queue number>
For example::
iptables -I INPUT -d 192.168.0.0/24 -j NFQUEUE --queue-num 1
The only special part of the rule is the target. Rules can have any match and
can be added to any table or chain.
Valid queue numbers are integers from 0 to 65,535 inclusive.
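One way to combine rules with the Python API, sketched here as an assumption rather than a documented recipe: mark packets from your callback and ``repeat()`` them, with an earlier rule accepting already-marked packets so they do not re-enter the queue::

    # Illustrative rules (run in a shell beforehand):
    #   iptables -I INPUT -d 192.168.0.0/24 -m mark --mark 1 -j ACCEPT
    #   iptables -A INPUT -d 192.168.0.0/24 -j NFQUEUE --queue-num 1
    from netfilterqueue import NetfilterQueue

    def mark_and_requeue(pkt):
        pkt.set_mark(1)   # matched by the "-m mark --mark 1" rule on the next pass
        pkt.repeat()      # restart this packet at the top of the chain

    nfqueue = NetfilterQueue()
    nfqueue.bind(1, mark_and_requeue)
    nfqueue.run()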
@ -228,7 +259,7 @@ Limitations
* Omits methods for getting information about the interface a packet has
arrived on or is leaving on
* Probably other stuff is omitted too
Source
======
@ -237,7 +268,7 @@ https://github.com/kti/python-netfilterqueue
License
=======
Copyright (c) 2011, Kerkhoff Technologies, Inc.
Copyright (c) 2011, Kerkhoff Technologies, Inc, and contributors.
`MIT licensed <https://github.com/kti/python-netfilterqueue/blob/master/LICENSE.txt>`_

ci.sh (new executable file)

@ -0,0 +1,12 @@
#!/bin/bash
set -ex -o pipefail
pip install -U pip setuptools wheel
sudo apt-get install libnetfilter-queue-dev
python setup.py sdist --formats=zip
pip install dist/*.zip
pip install -r test-requirements.txt
cd tests
pytest -W error -ra -v .

netfilterqueue.c (generated C source; diff suppressed because it is too large)

netfilterqueue.pxd

@ -141,7 +141,7 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h":
int nfq_fd(nfq_handle *h)
nfqnl_msg_packet_hdr *nfq_get_msg_packet_hdr(nfq_data *nfad)
int nfq_get_payload(nfq_data *nfad, char **data)
int nfq_get_payload(nfq_data *nfad, unsigned char **data)
int nfq_get_timestamp(nfq_data *nfad, timeval *tv)
nfqnl_msg_packet_hw *nfq_get_packet_hw(nfq_data *nfad)
int nfq_get_nfmark (nfq_data *nfad)
@ -168,14 +168,13 @@ cdef enum:
cdef class Packet:
cdef nfq_q_handle *_qh
cdef nfq_data *_nfa
cdef nfqnl_msg_packet_hdr *_hdr
cdef nfqnl_msg_packet_hw *_hw
cdef bint _verdict_is_set # True if verdict has been issued,
# false otherwise
cdef bint _mark_is_set # True if a mark has been given, false otherwise
cdef bint _hwaddr_is_set
cdef u_int32_t _given_mark # Mark given to packet
cdef bytes _given_payload # New payload of packet, or null
cdef bytes _owned_payload
# From NFQ packet header:
cdef readonly u_int32_t id
@ -185,7 +184,7 @@ cdef class Packet:
# Packet details:
cdef Py_ssize_t payload_len
cdef readonly char *payload
cdef readonly unsigned char *payload
cdef timeval timestamp
cdef u_int8_t hw_addr[8]
@ -198,12 +197,15 @@ cdef class Packet:
#cdef readonly u_int32_t physoutdev
cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa)
cdef drop_refs(self)
cdef void verdict(self, u_int8_t verdict)
cpdef Py_ssize_t get_payload_len(self)
cpdef double get_timestamp(self)
cpdef bytes get_payload(self)
cpdef set_payload(self, bytes payload)
cpdef set_mark(self, u_int32_t mark)
cpdef get_mark(self)
cpdef retain(self)
cpdef accept(self)
cpdef drop(self)
cpdef repeat(self)

netfilterqueue.pyx

@ -22,9 +22,21 @@ DEF MaxCopySize = BufferSize - MetadataSize
DEF SockOverhead = 760+20
DEF SockCopySize = MaxCopySize + SockOverhead
# Socket queue should hold max number of packets of copysize bytes
DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize / 2
DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize // 2
cdef extern from "Python.h":
const char* __FILE__
int __LINE__
cdef extern from *:
"""
#if PY_MAJOR_VERSION < 3
#define PyBytes_FromStringAndSize PyString_FromStringAndSize
#endif
"""
import socket
import warnings
cimport cpython.version
cdef int global_callback(nfq_q_handle *qh, nfgenmsg *nfmsg,
@ -35,6 +47,7 @@ cdef int global_callback(nfq_q_handle *qh, nfgenmsg *nfmsg,
packet = Packet()
packet.set_nfq_data(qh, nfa)
user_callback(packet)
packet.drop_refs()
return 1
cdef class Packet:
@ -54,21 +67,37 @@ cdef class Packet:
Assign a packet from NFQ to this object. Parse the header and load
local values.
"""
cdef nfqnl_msg_packet_hw *hw
cdef nfqnl_msg_packet_hdr *hdr
hdr = nfq_get_msg_packet_hdr(nfa)
self._qh = qh
self._nfa = nfa
self._hdr = nfq_get_msg_packet_hdr(nfa)
self.id = ntohl(hdr.packet_id)
self.hw_protocol = ntohs(hdr.hw_protocol)
self.hook = hdr.hook
self.id = ntohl(self._hdr.packet_id)
self.hw_protocol = ntohs(self._hdr.hw_protocol)
self.hook = self._hdr.hook
hw = nfq_get_packet_hw(nfa)
if hw == NULL:
# nfq_get_packet_hw doesn't work on OUTPUT and PREROUTING chains
self._hwaddr_is_set = False
else:
self.hw_addr = hw.hw_addr
self._hwaddr_is_set = True
self.payload_len = nfq_get_payload(self._nfa, &self.payload)
self.payload_len = nfq_get_payload(nfa, &self.payload)
if self.payload_len < 0:
raise OSError("Failed to get payload of packet.")
nfq_get_timestamp(self._nfa, &self.timestamp)
nfq_get_timestamp(nfa, &self.timestamp)
self.mark = nfq_get_nfmark(nfa)
cdef drop_refs(self):
"""
Called at the end of the user_callback, when the storage passed to
set_nfq_data() is about to be deallocated.
"""
self.payload = NULL
cdef void verdict(self, u_int8_t verdict):
"""Call appropriate set_verdict... function on packet."""
if self._verdict_is_set:
@ -99,23 +128,23 @@ cdef class Packet:
def get_hw(self):
"""Return the hardware address as Python string."""
self._hw = nfq_get_packet_hw(self._nfa)
if self._hw == NULL:
# nfq_get_packet_hw doesn't work on OUTPUT and PREROUTING chains
return None
self.hw_addr = self._hw.hw_addr
cdef object py_string
if cpython.version.PY_MAJOR_VERSION >= 3:
py_string = PyBytes_FromStringAndSize(<char*>self.hw_addr, 8)
else:
py_string = PyString_FromStringAndSize(<char*>self.hw_addr, 8)
py_string = PyBytes_FromStringAndSize(<char*>self.hw_addr, 8)
return py_string
def get_payload(self):
cpdef bytes get_payload(self):
"""Return payload as Python string."""
cdef object py_string
py_string = self.payload[:self.payload_len]
return py_string
if self._owned_payload:
return self._owned_payload
elif self.payload != NULL:
return self.payload[:self.payload_len]
else:
            raise RuntimeError(
                "Payload data is no longer available. You must call "
                "retain() within the user_callback in order to copy "
                "the payload if you need to inspect it after your "
                "callback has returned."
            )
cpdef Py_ssize_t get_payload_len(self):
return self.payload_len
@ -136,6 +165,9 @@ cdef class Packet:
return self._given_mark
return self.mark
cpdef retain(self):
self._owned_payload = self.get_payload()
cpdef accept(self):
"""Accept the packet."""
self.verdict(NF_ACCEPT)
@ -174,6 +206,9 @@ cdef class NetfilterQueue:
u_int32_t range=MaxPacketSize,
u_int32_t sock_len=SockRcvSize):
"""Create and bind to a new queue."""
if self.qh != NULL:
raise RuntimeError("A queue is already bound; use unbind() first")
cdef unsigned int newsiz
self.user_callback = user_callback
self.qh = nfq_create_queue(self.h, queue_num,
@ -184,14 +219,24 @@ cdef class NetfilterQueue:
if range > MaxCopySize:
range = MaxCopySize
if nfq_set_mode(self.qh, mode, range) < 0:
self.unbind()
raise OSError("Failed to set packet copy mode.")
nfq_set_queue_maxlen(self.qh, max_len)
newsiz = nfnl_rcvbufsiz(nfq_nfnlh(self.h),sock_len)
if newsiz != sock_len*2:
raise RuntimeWarning("Socket rcvbuf limit is now %d, requested %d." % (newsiz,sock_len))
newsiz = nfnl_rcvbufsiz(nfq_nfnlh(self.h), sock_len)
if newsiz != sock_len * 2:
try:
warnings.warn_explicit(
"Socket rcvbuf limit is now %d, requested %d." % (newsiz, sock_len),
category=RuntimeWarning,
filename=bytes(__FILE__).decode("ascii"),
lineno=__LINE__,
)
except: # if warnings are being treated as errors
self.unbind()
raise
def unbind(self):
"""Destroy the queue."""
if self.qh != NULL:
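The ``bind()`` change above reports an undersized receive buffer with a ``RuntimeWarning`` instead of raising. A rough sketch of how a caller could still treat that as fatal (the queue number, callback, and oversized ``sock_len`` are placeholders)::

    import warnings
    from netfilterqueue import NetfilterQueue

    nfq = NetfilterQueue()
    with warnings.catch_warnings():
        # Escalate the "Socket rcvbuf limit is now ..." warning to an exception;
        # bind() then unbinds itself and re-raises.
        warnings.simplefilter("error", RuntimeWarning)
        nfq.bind(1, lambda p: p.accept(), sock_len=2**30)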

setup.py

@ -1,38 +1,34 @@
from distutils.core import setup, Extension
from setuptools import setup, Extension
VERSION = "0.8.1" # Remember to change CHANGES.txt and netfilterqueue.pyx when version changes.
try:
# Use Cython
from Cython.Distutils import build_ext
cmd = {"build_ext": build_ext}
ext = Extension(
"netfilterqueue",
sources=["netfilterqueue.pyx",],
libraries=["netfilter_queue"],
)
from Cython.Build import cythonize
ext_modules = cythonize(
Extension(
"netfilterqueue", ["netfilterqueue.pyx"], libraries=["netfilter_queue"]
),
compiler_directives={"language_level": "3str"},
)
except ImportError:
# No Cython
cmd = {}
ext = Extension(
"netfilterqueue",
sources = ["netfilterqueue.c"],
libraries=["netfilter_queue"],
)
ext_modules = [
Extension("netfilterqueue", ["netfilterqueue.c"], libraries=["netfilter_queue"])
]
setup(
cmdclass = cmd,
ext_modules = [ext],
ext_modules=ext_modules,
name="NetfilterQueue",
version=VERSION,
license="MIT",
author="Matthew Fox",
author_email="matt@tansen.ca",
url="https://github.com/kti/python-netfilterqueue",
url="https://github.com/oremanj/python-netfilterqueue",
description="Python bindings for libnetfilter_queue",
long_description=open("README.rst").read(),
download_url="http://pypi.python.org/packages/source/N/NetfilterQueue/NetfilterQueue-%s.tar.gz" % VERSION,
classifiers = [
classifiers=[
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Operating System :: POSIX :: Linux",

test-requirements.in (new file)

@ -0,0 +1,5 @@
git+https://github.com/NightTsarina/python-unshare.git@4e98c177bdeb24c5dcfcd66c457845a776bbb75c
pytest
trio
pytest-trio
async_generator

test-requirements.txt (new file)

@ -0,0 +1,50 @@
#
# This file is autogenerated by pip-compile with python 3.9
# To update, run:
#
# pip-compile test-requirements.in
#
async-generator==1.10
# via
# -r test-requirements.in
# pytest-trio
# trio
attrs==21.4.0
# via
# outcome
# pytest
# trio
idna==3.3
# via trio
iniconfig==1.1.1
# via pytest
outcome==1.1.0
# via
# pytest-trio
# trio
packaging==21.3
# via pytest
pluggy==1.0.0
# via pytest
py==1.11.0
# via pytest
pyparsing==3.0.6
# via packaging
pytest==6.2.5
# via
# -r test-requirements.in
# pytest-trio
pytest-trio==0.7.0
# via -r test-requirements.in
python-unshare @ git+https://github.com/NightTsarina/python-unshare.git@4e98c177bdeb24c5dcfcd66c457845a776bbb75c
# via -r test-requirements.in
sniffio==1.2.0
# via trio
sortedcontainers==2.4.0
# via trio
toml==0.10.2
# via pytest
trio==0.19.0
# via
# -r test-requirements.in
# pytest-trio

tests/conftest.py (new file)

@ -0,0 +1,270 @@
import math
import os
import pytest
import socket
import subprocess
import sys
import trio
import unshare
import netfilterqueue
from typing import AsyncIterator
from async_generator import asynccontextmanager
from pytest_trio.enable_trio_mode import *
# We'll create three network namespaces, representing a router (which
# has interfaces on ROUTER_IP[1, 2]) and two hosts connected to it
# (PEER_IP[1, 2] respectively). The router (in the parent pytest
# process) will configure netfilterqueue iptables rules and use them
# to intercept and modify traffic between the two hosts (each of which
# is implemented in a subprocess).
#
# The 'peer' subprocesses communicate with each other over UDP, and
# with the router parent over a UNIX domain SOCK_SEQPACKET socketpair.
# Each packet sent from the parent to one peer over the UNIX domain
# socket will be forwarded to the other peer over UDP. Each packet
# received over UDP by either of the peers will be forwarded to its
# parent.
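#
# Rough topology (addresses are the constants below):
#
#   peer 1 (172.16.101.2) --veth-- router (172.16.101.1 / 172.16.102.1) --veth-- peer 2 (172.16.102.2)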
ROUTER_IP = {1: "172.16.101.1", 2: "172.16.102.1"}
PEER_IP = {1: "172.16.101.2", 2: "172.16.102.2"}
def enter_netns() -> None:
# Create new namespaces of the other types we need
unshare.unshare(unshare.CLONE_NEWNS | unshare.CLONE_NEWNET)
# Mount /sys so network tools work
subprocess.run("/bin/mount -t sysfs sys /sys".split(), check=True)
# Bind-mount /run so iptables can get its lock
subprocess.run("/bin/mount -t tmpfs tmpfs /run".split(), check=True)
# Set up loopback interface
subprocess.run("/sbin/ip link set lo up".split(), check=True)
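# This hook runs once, before pytest starts executing tests: it moves the
# whole session into fresh user/mount/network namespaces so the tests can
# configure interfaces and iptables rules without needing real root.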
@pytest.hookimpl(tryfirst=True)
def pytest_runtestloop():
if os.getuid() != 0:
# Create a new user namespace for the whole test session
outer = {"uid": os.getuid(), "gid": os.getgid()}
unshare.unshare(unshare.CLONE_NEWUSER)
with open("/proc/self/setgroups", "wb") as fp:
# This is required since we're unprivileged outside the namespace
fp.write(b"deny")
for idtype in ("uid", "gid"):
with open(f"/proc/self/{idtype}_map", "wb") as fp:
fp.write(b"0 %d 1" % (outer[idtype],))
assert os.getuid() == os.getgid() == 0
# Create a new network namespace for this pytest process
enter_netns()
with open("/proc/sys/net/ipv4/ip_forward", "wb") as fp:
fp.write(b"1\n")
async def peer_main(idx: int, parent_fd: int) -> None:
parent = trio.socket.fromfd(
parent_fd, socket.AF_UNIX, socket.SOCK_SEQPACKET
)
# Tell parent we've set up our netns, wait for it to confirm it's
# created our veth interface
await parent.send(b"ok")
assert b"ok" == await parent.recv(4096)
my_ip = PEER_IP[idx]
router_ip = ROUTER_IP[idx]
peer_ip = PEER_IP[3 - idx]
for cmd in (
f"ip link set veth0 up",
f"ip addr add {my_ip}/24 dev veth0",
f"ip route add default via {router_ip} dev veth0",
):
await trio.run_process(
cmd.split(), capture_stdout=True, capture_stderr=True
)
peer = trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
await peer.bind((my_ip, 0))
# Tell the parent our port and get our peer's port
await parent.send(b"%d" % peer.getsockname()[1])
peer_port = int(await parent.recv(4096))
await peer.connect((peer_ip, peer_port))
# Enter the message-forwarding loop
async def proxy_one_way(src, dest):
while True:
try:
msg = await src.recv(4096)
except trio.ClosedResourceError:
return
if not msg:
dest.close()
return
try:
await dest.send(msg)
except BrokenPipeError:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(proxy_one_way, parent, peer)
nursery.start_soon(proxy_one_way, peer, parent)
class Harness:
def __init__(self):
self._received = {}
self._conn = {}
self.failed = False
async def _run_peer(self, idx: int, *, task_status):
their_ip = PEER_IP[idx]
my_ip = ROUTER_IP[idx]
conn, child_conn = trio.socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
with conn:
try:
process = await trio.open_process(
[sys.executable, __file__, str(idx), str(child_conn.fileno())],
stdin=subprocess.DEVNULL,
pass_fds=[child_conn.fileno()],
preexec_fn=enter_netns,
)
finally:
child_conn.close()
assert b"ok" == await conn.recv(4096)
for cmd in (
f"ip link add veth{idx} type veth peer netns {process.pid} name veth0",
f"ip link set veth{idx} up",
f"ip addr add {my_ip}/24 dev veth{idx}",
):
await trio.run_process(cmd.split())
try:
await conn.send(b"ok")
self._conn[idx] = conn
task_status.started()
retval = await process.wait()
except BaseException:
process.kill()
with trio.CancelScope(shield=True):
await process.wait()
raise
else:
if retval != 0:
raise RuntimeError(
"peer subprocess exited with code {}".format(retval)
)
finally:
await trio.run_process(f"ip link delete veth{idx}".split())
async def _manage_peer(self, idx: int, *, task_status):
async with trio.open_nursery() as nursery:
await nursery.start(self._run_peer, idx)
packets_w, packets_r = trio.open_memory_channel(math.inf)
self._received[idx] = packets_r
task_status.started()
async with packets_w:
while True:
msg = await self._conn[idx].recv(4096)
if not msg:
break
await packets_w.send(msg)
@asynccontextmanager
async def run(self):
async with trio.open_nursery() as nursery:
async with trio.open_nursery() as start_nursery:
start_nursery.start_soon(nursery.start, self._manage_peer, 1)
start_nursery.start_soon(nursery.start, self._manage_peer, 2)
# Tell each peer about the other one's port
await self._conn[2].send(await self._received[1].receive())
await self._conn[1].send(await self._received[2].receive())
yield
self._conn[1].shutdown(socket.SHUT_WR)
self._conn[2].shutdown(socket.SHUT_WR)
if not self.failed:
for idx in (1, 2):
async for remainder in self._received[idx]:
raise AssertionError(
f"Peer {idx} received unexpected packet {remainder!r}"
)
@asynccontextmanager
async def capture_packets_to(
self, idx: int, *, queue_num: int = -1, **options
) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]:
packets_w, packets_r = trio.open_memory_channel(math.inf)
def stash_packet(p):
p.retain()
packets_w.send_nowait(p)
nfq = netfilterqueue.NetfilterQueue()
# Use a smaller socket buffer to avoid a warning in CI
options.setdefault("sock_len", 131072)
if queue_num >= 0:
nfq.bind(queue_num, stash_packet, **options)
else:
for queue_num in range(16):
try:
nfq.bind(queue_num, stash_packet, **options)
break
except Exception as ex:
last_error = ex
else:
raise RuntimeError(
"Couldn't bind any netfilter queue number between 0-15"
) from last_error
try:
rule = f"-d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}"
await trio.run_process(f"/sbin/iptables -A FORWARD {rule}".split())
try:
async with packets_w, trio.open_nursery() as nursery:
@nursery.start_soon
async def listen_for_packets():
while True:
await trio.lowlevel.wait_readable(nfq.get_fd())
nfq.run(block=False)
yield packets_r
nursery.cancel_scope.cancel()
finally:
await trio.run_process(f"/sbin/iptables -D FORWARD {rule}".split())
finally:
nfq.unbind()
async def expect(self, idx: int, *packets: bytes):
for expected in packets:
with trio.move_on_after(5) as scope:
received = await self._received[idx].receive()
if scope.cancelled_caught:
self.failed = True
raise AssertionError(
f"Timeout waiting for peer {idx} to receive {expected!r}"
)
if received != expected:
self.failed = True
raise AssertionError(
f"Expected peer {idx} to receive {expected!r} but it "
f"received {received!r}"
)
async def send(self, idx: int, *packets: bytes):
for packet in packets:
await self._conn[3 - idx].send(packet)
@pytest.fixture
async def harness() -> Harness:
h = Harness()
async with h.run():
yield h
if __name__ == "__main__":
trio.run(peer_main, int(sys.argv[1]), int(sys.argv[2]))

tests/test_basic.py (new file)

@ -0,0 +1,95 @@
import struct
import trio
import pytest
async def test_comms_without_queue(harness):
await harness.send(2, b"hello", b"world")
await harness.expect(2, b"hello", b"world")
await harness.send(1, b"it works?")
await harness.expect(1, b"it works?")
async def test_queue_dropping(harness):
async def drop(packets, msg):
async for packet in packets:
if packet.get_payload()[28:] == msg:
packet.drop()
else:
packet.accept()
async with trio.open_nursery() as nursery:
async with harness.capture_packets_to(2) as p2, \
harness.capture_packets_to(1) as p1:
nursery.start_soon(drop, p2, b"two")
nursery.start_soon(drop, p1, b"one")
await harness.send(2, b"one", b"two", b"three")
await harness.send(1, b"one", b"two", b"three")
await harness.expect(2, b"one", b"three")
await harness.expect(1, b"two", b"three")
# Once we stop capturing, everything gets through again:
await harness.send(2, b"one", b"two", b"three")
await harness.send(1, b"one", b"two", b"three")
await harness.expect(2, b"one", b"two", b"three")
await harness.expect(1, b"one", b"two", b"three")
async def test_rewrite_reorder(harness):
async def munge(packets):
def set_udp_payload(p, msg):
data = bytearray(p.get_payload())
old_len = len(data) - 28
if len(msg) != old_len:
data[2:4] = struct.pack(">H", len(msg) + 28)
data[24:26] = struct.pack(">H", len(msg) + 8)
# Recompute checksum too
data[10:12] = b"\x00\x00"
words = struct.unpack(">10H", data[:20])
cksum = sum(words)
while cksum >> 16:
cksum = (cksum & 0xFFFF) + (cksum >> 16)
data[10:12] = struct.pack(">H", cksum ^ 0xFFFF)
# Clear UDP checksum and set payload
data[28:] = msg
data[26:28] = b"\x00\x00"
p.set_payload(bytes(data))
async for packet in packets:
payload = packet.get_payload()[28:]
if payload == b"one":
set_udp_payload(packet, b"numero uno")
packet.accept()
elif payload == b"two":
two = packet
elif payload == b"three":
set_udp_payload(two, b"TWO")
packet.accept()
two.accept()
else:
packet.accept()
async with trio.open_nursery() as nursery:
async with harness.capture_packets_to(2) as p2:
nursery.start_soon(munge, p2)
await harness.send(2, b"one", b"two", b"three", b"four")
await harness.expect(2, b"numero uno", b"three", b"TWO", b"four")
async def test_errors(harness):
with pytest.warns(RuntimeWarning, match="rcvbuf limit is"):
async with harness.capture_packets_to(2, sock_len=2**30):
pass
async with harness.capture_packets_to(2, queue_num=0):
with pytest.raises(OSError, match="Failed to create queue"):
async with harness.capture_packets_to(2, queue_num=0):
pass
from netfilterqueue import NetfilterQueue
nfq = NetfilterQueue()
nfq.bind(1, lambda p: None, sock_len=131072)
with pytest.raises(RuntimeError, match="A queue is already bound"):
nfq.bind(2, lambda p: None, sock_len=131072)