Modernize, add tests, allow Packet to outlive the callback it's passed to
This commit is contained in:
parent
ec2ae29066
commit
9587d75aff
|
@ -0,0 +1,34 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
Ubuntu:
|
||||
name: 'Ubuntu (${{ matrix.python }})'
|
||||
timeout-minutes: 10
|
||||
runs-on: 'ubuntu-latest'
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python:
|
||||
- '3.6'
|
||||
- '3.7'
|
||||
- '3.8'
|
||||
- '3.9'
|
||||
- '3.10'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
- name: Setup python
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }}
|
||||
- name: Run tests
|
||||
run: ./ci.sh
|
||||
env:
|
||||
# Should match 'name:' up above
|
||||
JOB_NAME: 'Ubuntu (${{ matrix.python }})'
|
|
@ -1,4 +1,10 @@
|
|||
v0.8.1 30 Jan 2017
|
||||
v0.9.0, unreleased
|
||||
Improve usability when Packet objects are retained past the callback
|
||||
Add Packet.retain() to save the packet contents in such cases
|
||||
Eliminate warnings during build on py3
|
||||
Add CI and basic test suite
|
||||
|
||||
v0.8.1, 30 Jan 2017
|
||||
Fix bug #25- crashing when used in OUTPUT or POSTROUTING chains
|
||||
|
||||
v0.8, 15 Dec 2016
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
#!/bin/bash
|
||||
|
||||
set -ex -o pipefail
|
||||
|
||||
pip install -U pip setuptools wheel
|
||||
sudo apt-get install libnetfilterqueue-dev
|
||||
python setup.py sdist --formats=zip
|
||||
pip install dist/*.zip
|
||||
pip install -r test-requirements.txt
|
||||
|
||||
cd tests
|
||||
pytest -W error -ra -v .
|
|
@ -141,7 +141,7 @@ cdef extern from "libnetfilter_queue/libnetfilter_queue.h":
|
|||
|
||||
int nfq_fd(nfq_handle *h)
|
||||
nfqnl_msg_packet_hdr *nfq_get_msg_packet_hdr(nfq_data *nfad)
|
||||
int nfq_get_payload(nfq_data *nfad, char **data)
|
||||
int nfq_get_payload(nfq_data *nfad, unsigned char **data)
|
||||
int nfq_get_timestamp(nfq_data *nfad, timeval *tv)
|
||||
nfqnl_msg_packet_hw *nfq_get_packet_hw(nfq_data *nfad)
|
||||
int nfq_get_nfmark (nfq_data *nfad)
|
||||
|
@ -168,14 +168,13 @@ cdef enum:
|
|||
|
||||
cdef class Packet:
|
||||
cdef nfq_q_handle *_qh
|
||||
cdef nfq_data *_nfa
|
||||
cdef nfqnl_msg_packet_hdr *_hdr
|
||||
cdef nfqnl_msg_packet_hw *_hw
|
||||
cdef bint _verdict_is_set # True if verdict has been issued,
|
||||
# false otherwise
|
||||
cdef bint _mark_is_set # True if a mark has been given, false otherwise
|
||||
cdef bint _hwaddr_is_set
|
||||
cdef u_int32_t _given_mark # Mark given to packet
|
||||
cdef bytes _given_payload # New payload of packet, or null
|
||||
cdef bytes _owned_payload
|
||||
|
||||
# From NFQ packet header:
|
||||
cdef readonly u_int32_t id
|
||||
|
@ -185,7 +184,7 @@ cdef class Packet:
|
|||
|
||||
# Packet details:
|
||||
cdef Py_ssize_t payload_len
|
||||
cdef readonly char *payload
|
||||
cdef readonly unsigned char *payload
|
||||
cdef timeval timestamp
|
||||
cdef u_int8_t hw_addr[8]
|
||||
|
||||
|
@ -198,12 +197,15 @@ cdef class Packet:
|
|||
#cdef readonly u_int32_t physoutdev
|
||||
|
||||
cdef set_nfq_data(self, nfq_q_handle *qh, nfq_data *nfa)
|
||||
cdef drop_refs(self)
|
||||
cdef void verdict(self, u_int8_t verdict)
|
||||
cpdef Py_ssize_t get_payload_len(self)
|
||||
cpdef double get_timestamp(self)
|
||||
cpdef bytes get_payload(self)
|
||||
cpdef set_payload(self, bytes payload)
|
||||
cpdef set_mark(self, u_int32_t mark)
|
||||
cpdef get_mark(self)
|
||||
cpdef retain(self)
|
||||
cpdef accept(self)
|
||||
cpdef drop(self)
|
||||
cpdef repeat(self)
|
||||
|
|
|
@ -22,7 +22,14 @@ DEF MaxCopySize = BufferSize - MetadataSize
|
|||
DEF SockOverhead = 760+20
|
||||
DEF SockCopySize = MaxCopySize + SockOverhead
|
||||
# Socket queue should hold max number of packets of copysize bytes
|
||||
DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize / 2
|
||||
DEF SockRcvSize = DEFAULT_MAX_QUEUELEN * SockCopySize // 2
|
||||
|
||||
cdef extern from *:
|
||||
"""
|
||||
#if PY_MAJOR_VERSION < 3
|
||||
#define PyBytes_FromStringAndSize PyString_FromStringAndSize
|
||||
#endif
|
||||
"""
|
||||
|
||||
import socket
|
||||
cimport cpython.version
|
||||
|
@ -35,6 +42,7 @@ cdef int global_callback(nfq_q_handle *qh, nfgenmsg *nfmsg,
|
|||
packet = Packet()
|
||||
packet.set_nfq_data(qh, nfa)
|
||||
user_callback(packet)
|
||||
packet.drop_refs()
|
||||
return 1
|
||||
|
||||
cdef class Packet:
|
||||
|
@ -54,21 +62,37 @@ cdef class Packet:
|
|||
Assign a packet from NFQ to this object. Parse the header and load
|
||||
local values.
|
||||
"""
|
||||
cdef nfqnl_msg_packet_hw *hw
|
||||
cdef nfqnl_msg_packet_hdr *hdr
|
||||
|
||||
hdr = nfq_get_msg_packet_hdr(nfa)
|
||||
self._qh = qh
|
||||
self._nfa = nfa
|
||||
self._hdr = nfq_get_msg_packet_hdr(nfa)
|
||||
self.id = ntohl(hdr.packet_id)
|
||||
self.hw_protocol = ntohs(hdr.hw_protocol)
|
||||
self.hook = hdr.hook
|
||||
|
||||
self.id = ntohl(self._hdr.packet_id)
|
||||
self.hw_protocol = ntohs(self._hdr.hw_protocol)
|
||||
self.hook = self._hdr.hook
|
||||
hw = nfq_get_packet_hw(nfa)
|
||||
if hw == NULL:
|
||||
# nfq_get_packet_hw doesn't work on OUTPUT and PREROUTING chains
|
||||
self._hwaddr_is_set = False
|
||||
else:
|
||||
self.hw_addr = hw.hw_addr
|
||||
self._hwaddr_is_set = True
|
||||
|
||||
self.payload_len = nfq_get_payload(self._nfa, &self.payload)
|
||||
self.payload_len = nfq_get_payload(nfa, &self.payload)
|
||||
if self.payload_len < 0:
|
||||
raise OSError("Failed to get payload of packet.")
|
||||
|
||||
nfq_get_timestamp(self._nfa, &self.timestamp)
|
||||
nfq_get_timestamp(nfa, &self.timestamp)
|
||||
self.mark = nfq_get_nfmark(nfa)
|
||||
|
||||
cdef drop_refs(self):
|
||||
"""
|
||||
Called at the end of the user_callback, when the storage passed to
|
||||
set_nfq_data() is about to be deallocated.
|
||||
"""
|
||||
self.payload = NULL
|
||||
|
||||
cdef void verdict(self, u_int8_t verdict):
|
||||
"""Call appropriate set_verdict... function on packet."""
|
||||
if self._verdict_is_set:
|
||||
|
@ -99,23 +123,23 @@ cdef class Packet:
|
|||
|
||||
def get_hw(self):
|
||||
"""Return the hardware address as Python string."""
|
||||
self._hw = nfq_get_packet_hw(self._nfa)
|
||||
if self._hw == NULL:
|
||||
# nfq_get_packet_hw doesn't work on OUTPUT and PREROUTING chains
|
||||
return None
|
||||
self.hw_addr = self._hw.hw_addr
|
||||
cdef object py_string
|
||||
if cpython.version.PY_MAJOR_VERSION >= 3:
|
||||
py_string = PyBytes_FromStringAndSize(<char*>self.hw_addr, 8)
|
||||
else:
|
||||
py_string = PyString_FromStringAndSize(<char*>self.hw_addr, 8)
|
||||
py_string = PyBytes_FromStringAndSize(<char*>self.hw_addr, 8)
|
||||
return py_string
|
||||
|
||||
def get_payload(self):
|
||||
cpdef bytes get_payload(self):
|
||||
"""Return payload as Python string."""
|
||||
cdef object py_string
|
||||
py_string = self.payload[:self.payload_len]
|
||||
return py_string
|
||||
if self._owned_payload:
|
||||
return self._owned_payload
|
||||
elif self.payload != NULL:
|
||||
return self.payload[:self.payload_len]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Payload data is no longer available. You must call "
|
||||
"retain() within the user_callback in order to copy "
|
||||
"the payload if you need to expect it after your "
|
||||
"callback has returned."
|
||||
)
|
||||
|
||||
cpdef Py_ssize_t get_payload_len(self):
|
||||
return self.payload_len
|
||||
|
@ -136,6 +160,9 @@ cdef class Packet:
|
|||
return self._given_mark
|
||||
return self.mark
|
||||
|
||||
cpdef retain(self):
|
||||
self._owned_payload = self.get_payload()
|
||||
|
||||
cpdef accept(self):
|
||||
"""Accept the packet."""
|
||||
self.verdict(NF_ACCEPT)
|
||||
|
@ -191,7 +218,7 @@ cdef class NetfilterQueue:
|
|||
newsiz = nfnl_rcvbufsiz(nfq_nfnlh(self.h),sock_len)
|
||||
if newsiz != sock_len*2:
|
||||
raise RuntimeWarning("Socket rcvbuf limit is now %d, requested %d." % (newsiz,sock_len))
|
||||
|
||||
|
||||
def unbind(self):
|
||||
"""Destroy the queue."""
|
||||
if self.qh != NULL:
|
||||
|
|
32
setup.py
32
setup.py
|
@ -1,38 +1,34 @@
|
|||
from distutils.core import setup, Extension
|
||||
from setuptools import setup, Extension
|
||||
|
||||
VERSION = "0.8.1" # Remember to change CHANGES.txt and netfilterqueue.pyx when version changes.
|
||||
|
||||
try:
|
||||
# Use Cython
|
||||
from Cython.Distutils import build_ext
|
||||
cmd = {"build_ext": build_ext}
|
||||
ext = Extension(
|
||||
"netfilterqueue",
|
||||
sources=["netfilterqueue.pyx",],
|
||||
libraries=["netfilter_queue"],
|
||||
)
|
||||
from Cython.Build import cythonize
|
||||
ext_modules = cythonize(
|
||||
Extension(
|
||||
"netfilterqueue", ["netfilterqueue.pyx"], libraries=["netfilter_queue"]
|
||||
),
|
||||
compiler_directives={"language_level": "3str"},
|
||||
)
|
||||
except ImportError:
|
||||
# No Cython
|
||||
cmd = {}
|
||||
ext = Extension(
|
||||
"netfilterqueue",
|
||||
sources = ["netfilterqueue.c"],
|
||||
libraries=["netfilter_queue"],
|
||||
)
|
||||
ext_modules = [
|
||||
Extension("netfilterqueue", ["netfilterqueue.c"], libraries=["netfilter_queue"])
|
||||
]
|
||||
|
||||
setup(
|
||||
cmdclass = cmd,
|
||||
ext_modules = [ext],
|
||||
ext_modules=ext_modules,
|
||||
name="NetfilterQueue",
|
||||
version=VERSION,
|
||||
license="MIT",
|
||||
author="Matthew Fox",
|
||||
author_email="matt@tansen.ca",
|
||||
url="https://github.com/kti/python-netfilterqueue",
|
||||
url="https://github.com/oremanj/python-netfilterqueue",
|
||||
description="Python bindings for libnetfilter_queue",
|
||||
long_description=open("README.rst").read(),
|
||||
download_url="http://pypi.python.org/packages/source/N/NetfilterQueue/NetfilterQueue-%s.tar.gz" % VERSION,
|
||||
classifiers = [
|
||||
classifiers=[
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
git+https://github.com/NightTsarina/python-unshare.git@4e98c177bdeb24c5dcfcd66c457845a776bbb75c
|
||||
pytest
|
||||
trio
|
||||
pytest-trio
|
||||
async_generator
|
|
@ -0,0 +1,50 @@
|
|||
#
|
||||
# This file is autogenerated by pip-compile with python 3.9
|
||||
# To update, run:
|
||||
#
|
||||
# pip-compile test-requirements.in
|
||||
#
|
||||
async-generator==1.10
|
||||
# via
|
||||
# -r test-requirements.in
|
||||
# pytest-trio
|
||||
# trio
|
||||
attrs==21.4.0
|
||||
# via
|
||||
# outcome
|
||||
# pytest
|
||||
# trio
|
||||
idna==3.3
|
||||
# via trio
|
||||
iniconfig==1.1.1
|
||||
# via pytest
|
||||
outcome==1.1.0
|
||||
# via
|
||||
# pytest-trio
|
||||
# trio
|
||||
packaging==21.3
|
||||
# via pytest
|
||||
pluggy==1.0.0
|
||||
# via pytest
|
||||
py==1.11.0
|
||||
# via pytest
|
||||
pyparsing==3.0.6
|
||||
# via packaging
|
||||
pytest==6.2.5
|
||||
# via
|
||||
# -r test-requirements.in
|
||||
# pytest-trio
|
||||
pytest-trio==0.7.0
|
||||
# via -r test-requirements.in
|
||||
python-unshare @ git+https://github.com/NightTsarina/python-unshare.git@4e98c177bdeb24c5dcfcd66c457845a776bbb75c
|
||||
# via -r test-requirements.in
|
||||
sniffio==1.2.0
|
||||
# via trio
|
||||
sortedcontainers==2.4.0
|
||||
# via trio
|
||||
toml==0.10.2
|
||||
# via pytest
|
||||
trio==0.19.0
|
||||
# via
|
||||
# -r test-requirements.in
|
||||
# pytest-trio
|
|
@ -0,0 +1,268 @@
|
|||
import math
|
||||
import os
|
||||
import pytest
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import trio
|
||||
import unshare
|
||||
import netfilterqueue
|
||||
from typing import AsyncIterator
|
||||
from async_generator import asynccontextmanager
|
||||
from pytest_trio.enable_trio_mode import *
|
||||
|
||||
|
||||
# We'll create three network namespaces, representing a router (which
|
||||
# has interfaces on ROUTER_IP[1, 2]) and two hosts connected to it
|
||||
# (PEER_IP[1, 2] respectively). The router (in the parent pytest
|
||||
# process) will configure netfilterqueue iptables rules and use them
|
||||
# to intercept and modify traffic between the two hosts (each of which
|
||||
# is implemented in a subprocess).
|
||||
#
|
||||
# The 'peer' subprocesses communicate with each other over UDP, and
|
||||
# with the router parent over a UNIX domain SOCK_SEQPACKET socketpair.
|
||||
# Each packet sent from the parent to one peer over the UNIX domain
|
||||
# socket will be forwarded to the other peer over UDP. Each packet
|
||||
# received over UDP by either of the peers will be forwarded to its
|
||||
# parent.
|
||||
|
||||
ROUTER_IP = {1: "172.16.101.1", 2: "172.16.102.1"}
|
||||
PEER_IP = {1: "172.16.101.2", 2: "172.16.102.2"}
|
||||
|
||||
|
||||
def enter_netns() -> None:
|
||||
# Create new namespaces of the other types we need
|
||||
unshare.unshare(unshare.CLONE_NEWNS | unshare.CLONE_NEWNET)
|
||||
|
||||
# Mount /sys so network tools work
|
||||
subprocess.run("/bin/mount -t sysfs sys /sys".split(), check=True)
|
||||
|
||||
# Bind-mount /run so iptables can get its lock
|
||||
subprocess.run("/bin/mount -t tmpfs tmpfs /run".split(), check=True)
|
||||
|
||||
# Set up loopback interface
|
||||
subprocess.run("/sbin/ip link set lo up".split(), check=True)
|
||||
|
||||
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_runtestloop():
|
||||
if os.getuid() != 0:
|
||||
# Create a new user namespace for the whole test session
|
||||
outer = {"uid": os.getuid(), "gid": os.getgid()}
|
||||
unshare.unshare(unshare.CLONE_NEWUSER)
|
||||
with open("/proc/self/setgroups", "wb") as fp:
|
||||
# This is required since we're unprivileged outside the namespace
|
||||
fp.write(b"deny")
|
||||
for idtype in ("uid", "gid"):
|
||||
with open(f"/proc/self/{idtype}_map", "wb") as fp:
|
||||
fp.write(b"0 %d 1" % (outer[idtype],))
|
||||
assert os.getuid() == os.getgid() == 0
|
||||
|
||||
# Create a new network namespace for this pytest process
|
||||
enter_netns()
|
||||
with open("/proc/sys/net/ipv4/ip_forward", "wb") as fp:
|
||||
fp.write(b"1\n")
|
||||
|
||||
|
||||
async def peer_main(idx: int, parent_fd: int) -> None:
|
||||
parent = trio.socket.fromfd(
|
||||
parent_fd, socket.AF_UNIX, socket.SOCK_SEQPACKET
|
||||
)
|
||||
|
||||
# Tell parent we've set up our netns, wait for it to confirm it's
|
||||
# created our veth interface
|
||||
await parent.send(b"ok")
|
||||
assert b"ok" == await parent.recv(4096)
|
||||
|
||||
my_ip = PEER_IP[idx]
|
||||
router_ip = ROUTER_IP[idx]
|
||||
peer_ip = PEER_IP[3 - idx]
|
||||
|
||||
for cmd in (
|
||||
f"ip link set veth0 up",
|
||||
f"ip addr add {my_ip}/24 dev veth0",
|
||||
f"ip route add default via {router_ip} dev veth0",
|
||||
):
|
||||
await trio.run_process(
|
||||
cmd.split(), capture_stdout=True, capture_stderr=True
|
||||
)
|
||||
|
||||
peer = trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
await peer.bind((my_ip, 0))
|
||||
|
||||
# Tell the parent our port and get our peer's port
|
||||
await parent.send(b"%d" % peer.getsockname()[1])
|
||||
peer_port = int(await parent.recv(4096))
|
||||
await peer.connect((peer_ip, peer_port))
|
||||
|
||||
# Enter the message-forwarding loop
|
||||
async def proxy_one_way(src, dest):
|
||||
while True:
|
||||
try:
|
||||
msg = await src.recv(4096)
|
||||
except trio.ClosedResourceError:
|
||||
return
|
||||
if not msg:
|
||||
dest.close()
|
||||
return
|
||||
try:
|
||||
await dest.send(msg)
|
||||
except BrokenPipeError:
|
||||
return
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(proxy_one_way, parent, peer)
|
||||
nursery.start_soon(proxy_one_way, peer, parent)
|
||||
|
||||
|
||||
class Harness:
|
||||
def __init__(self):
|
||||
self._received = {}
|
||||
self._conn = {}
|
||||
self.failed = False
|
||||
|
||||
async def _run_peer(self, idx: int, *, task_status):
|
||||
their_ip = PEER_IP[idx]
|
||||
my_ip = ROUTER_IP[idx]
|
||||
conn, child_conn = trio.socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET)
|
||||
with conn:
|
||||
try:
|
||||
process = await trio.open_process(
|
||||
[sys.executable, __file__, str(idx), str(child_conn.fileno())],
|
||||
stdin=subprocess.DEVNULL,
|
||||
pass_fds=[child_conn.fileno()],
|
||||
preexec_fn=enter_netns,
|
||||
)
|
||||
finally:
|
||||
child_conn.close()
|
||||
assert b"ok" == await conn.recv(4096)
|
||||
for cmd in (
|
||||
f"ip link add veth{idx} type veth peer netns {process.pid} name veth0",
|
||||
f"ip link set veth{idx} up",
|
||||
f"ip addr add {my_ip}/24 dev veth{idx}",
|
||||
):
|
||||
await trio.run_process(cmd.split())
|
||||
|
||||
try:
|
||||
await conn.send(b"ok")
|
||||
self._conn[idx] = conn
|
||||
task_status.started()
|
||||
retval = await process.wait()
|
||||
except BaseException:
|
||||
process.kill()
|
||||
with trio.CancelScope(shield=True):
|
||||
await process.wait()
|
||||
raise
|
||||
else:
|
||||
if retval != 0:
|
||||
raise RuntimeError(
|
||||
"peer subprocess exited with code {}".format(retval)
|
||||
)
|
||||
finally:
|
||||
await trio.run_process(f"ip link delete veth{idx}".split())
|
||||
|
||||
async def _manage_peer(self, idx: int, *, task_status):
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(self._run_peer, idx)
|
||||
packets_w, packets_r = trio.open_memory_channel(math.inf)
|
||||
self._received[idx] = packets_r
|
||||
task_status.started()
|
||||
async with packets_w:
|
||||
while True:
|
||||
msg = await self._conn[idx].recv(4096)
|
||||
if not msg:
|
||||
break
|
||||
await packets_w.send(msg)
|
||||
|
||||
@asynccontextmanager
|
||||
async def run(self):
|
||||
async with trio.open_nursery() as nursery:
|
||||
async with trio.open_nursery() as start_nursery:
|
||||
start_nursery.start_soon(nursery.start, self._manage_peer, 1)
|
||||
start_nursery.start_soon(nursery.start, self._manage_peer, 2)
|
||||
# Tell each peer about the other one's port
|
||||
await self._conn[2].send(await self._received[1].receive())
|
||||
await self._conn[1].send(await self._received[2].receive())
|
||||
yield
|
||||
self._conn[1].shutdown(socket.SHUT_WR)
|
||||
self._conn[2].shutdown(socket.SHUT_WR)
|
||||
|
||||
if not self.failed:
|
||||
for idx in (1, 2):
|
||||
async for remainder in self._received[idx]:
|
||||
raise AssertionError(
|
||||
f"Peer {idx} received unexepcted packet {remainder!r}"
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def capture_packets_to(
|
||||
self, idx: int, *, queue_num: int = -1, **options
|
||||
) -> AsyncIterator["trio.MemoryReceiveChannel[netfilterqueue.Packet]"]:
|
||||
|
||||
packets_w, packets_r = trio.open_memory_channel(math.inf)
|
||||
|
||||
def stash_packet(p):
|
||||
p.retain()
|
||||
packets_w.send_nowait(p)
|
||||
|
||||
nfq = netfilterqueue.NetfilterQueue()
|
||||
if queue_num >= 0:
|
||||
nfq.bind(queue_num, stash_packet, **options)
|
||||
else:
|
||||
for queue_num in range(16):
|
||||
try:
|
||||
nfq.bind(queue_num, stash_packet, **options)
|
||||
break
|
||||
except Exception as ex:
|
||||
last_error = ex
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Couldn't bind any netfilter queue number between 0-15"
|
||||
)
|
||||
try:
|
||||
rule = f"-d {PEER_IP[idx]} -j NFQUEUE --queue-num {queue_num}"
|
||||
await trio.run_process(f"/sbin/iptables -A FORWARD {rule}".split())
|
||||
try:
|
||||
async with packets_w, trio.open_nursery() as nursery:
|
||||
@nursery.start_soon
|
||||
async def listen_for_packets():
|
||||
while True:
|
||||
await trio.lowlevel.wait_readable(nfq.get_fd())
|
||||
nfq.run(block=False)
|
||||
yield packets_r
|
||||
nursery.cancel_scope.cancel()
|
||||
finally:
|
||||
await trio.run_process(f"/sbin/iptables -D FORWARD {rule}".split())
|
||||
finally:
|
||||
nfq.unbind()
|
||||
|
||||
async def expect(self, idx: int, *packets: bytes):
|
||||
for expected in packets:
|
||||
with trio.move_on_after(5) as scope:
|
||||
received = await self._received[idx].receive()
|
||||
if scope.cancelled_caught:
|
||||
self.failed = True
|
||||
raise AssertionError(
|
||||
f"Timeout waiting for peer {idx} to receive {expected!r}"
|
||||
)
|
||||
if received != expected:
|
||||
self.failed = True
|
||||
raise AssertionError(
|
||||
f"Expected peer {idx} to receive {expected!r} but it "
|
||||
f"received {received!r}"
|
||||
)
|
||||
|
||||
async def send(self, idx: int, *packets: bytes):
|
||||
for packet in packets:
|
||||
await self._conn[3 - idx].send(packet)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def harness() -> Harness:
|
||||
h = Harness()
|
||||
async with h.run():
|
||||
yield h
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(peer_main, int(sys.argv[1]), int(sys.argv[2]))
|
|
@ -0,0 +1,76 @@
|
|||
import struct
|
||||
import trio
|
||||
|
||||
|
||||
async def test_comms_without_queue(harness):
|
||||
await harness.send(2, b"hello", b"world")
|
||||
await harness.expect(2, b"hello", b"world")
|
||||
await harness.send(1, b"it works?")
|
||||
await harness.expect(1, b"it works?")
|
||||
|
||||
|
||||
async def test_queue_dropping(harness):
|
||||
async def drop(packets, msg):
|
||||
async for packet in packets:
|
||||
if packet.get_payload()[28:] == msg:
|
||||
packet.drop()
|
||||
else:
|
||||
packet.accept()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
async with harness.capture_packets_to(2) as p2, \
|
||||
harness.capture_packets_to(1) as p1:
|
||||
nursery.start_soon(drop, p2, b"two")
|
||||
nursery.start_soon(drop, p1, b"one")
|
||||
|
||||
await harness.send(2, b"one", b"two", b"three")
|
||||
await harness.send(1, b"one", b"two", b"three")
|
||||
await harness.expect(2, b"one", b"three")
|
||||
await harness.expect(1, b"two", b"three")
|
||||
|
||||
# Once we stop capturing, everything gets through again:
|
||||
await harness.send(2, b"one", b"two", b"three")
|
||||
await harness.send(1, b"one", b"two", b"three")
|
||||
await harness.expect(2, b"one", b"two", b"three")
|
||||
await harness.expect(1, b"one", b"two", b"three")
|
||||
|
||||
|
||||
async def test_rewrite_reorder(harness):
|
||||
async def munge(packets):
|
||||
def set_udp_payload(p, msg):
|
||||
data = bytearray(p.get_payload())
|
||||
old_len = len(data) - 28
|
||||
if len(msg) != old_len:
|
||||
data[2:4] = struct.pack(">H", len(msg) + 28)
|
||||
data[24:26] = struct.pack(">H", len(msg) + 8)
|
||||
# Recompute checksum too
|
||||
data[10:12] = b"\x00\x00"
|
||||
words = struct.unpack(">10H", data[:20])
|
||||
cksum = sum(words)
|
||||
while cksum >> 16:
|
||||
cksum = (cksum & 0xFFFF) + (cksum >> 16)
|
||||
data[10:12] = struct.pack(">H", cksum ^ 0xFFFF)
|
||||
# Clear UDP checksum and set payload
|
||||
data[28:] = msg
|
||||
data[26:28] = b"\x00\x00"
|
||||
p.set_payload(bytes(data))
|
||||
|
||||
async for packet in packets:
|
||||
payload = packet.get_payload()[28:]
|
||||
if payload == b"one":
|
||||
set_udp_payload(packet, b"numero uno")
|
||||
packet.accept()
|
||||
elif payload == b"two":
|
||||
two = packet
|
||||
elif payload == b"three":
|
||||
set_udp_payload(two, b"TWO")
|
||||
packet.accept()
|
||||
two.accept()
|
||||
else:
|
||||
packet.accept()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
async with harness.capture_packets_to(2) as p2:
|
||||
nursery.start_soon(munge, p2)
|
||||
await harness.send(2, b"one", b"two", b"three", b"four")
|
||||
await harness.expect(2, b"numero uno", b"three", b"TWO", b"four")
|
Loading…
Reference in New Issue