Fix TCP manager and restarts (#1556)

* Fix TCP manager and restarts

* clippy

* clippy

* clippy
This commit is contained in:
Andrea Fioraldi 2023-09-28 13:46:07 +02:00 committed by GitHub
parent 652c24cb2a
commit 19aac2fc04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -11,8 +11,10 @@ use core::{
sync::atomic::{compiler_fence, Ordering}, sync::atomic::{compiler_fence, Ordering},
}; };
use std::{ use std::{
env,
io::{ErrorKind, Read, Write}, io::{ErrorKind, Read, Write},
net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}, net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs},
sync::Arc,
}; };
#[cfg(feature = "std")] #[cfg(feature = "std")]
@ -30,7 +32,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt}, io::{AsyncReadExt, AsyncWriteExt},
sync::{broadcast, mpsc}, sync::{broadcast, mpsc},
task::spawn, task::{spawn, JoinHandle},
}; };
#[cfg(feature = "std")] #[cfg(feature = "std")]
use typed_builder::TypedBuilder; use typed_builder::TypedBuilder;
@ -75,6 +77,8 @@ where
phantom: PhantomData<I>, phantom: PhantomData<I>,
} }
const UNDEFINED_CLIENT_ID: ClientId = ClientId(0xffffffff);
impl<I, MT> TcpEventBroker<I, MT> impl<I, MT> TcpEventBroker<I, MT>
where where
I: Input, I: Input,
@ -105,9 +109,10 @@ where
/// Run in the broker until all clients exit /// Run in the broker until all clients exit
#[tokio::main(flavor = "current_thread")] #[tokio::main(flavor = "current_thread")]
#[allow(clippy::too_many_lines)]
pub async fn broker_loop(&mut self) -> Result<(), Error> { pub async fn broker_loop(&mut self) -> Result<(), Error> {
let (tx_bc, rx) = broadcast::channel(128); let (tx_bc, rx) = broadcast::channel(1024);
let (tx, mut rx_mpsc) = mpsc::channel(128); let (tx, mut rx_mpsc) = mpsc::channel(1024);
let exit_cleanly_after = self.exit_cleanly_after; let exit_cleanly_after = self.exit_cleanly_after;
@ -118,36 +123,61 @@ where
let listener = tokio::net::TcpListener::from_std(listener)?; let listener = tokio::net::TcpListener::from_std(listener)?;
let tokio_broker = spawn(async move { let tokio_broker = spawn(async move {
let mut recv_handles = vec![]; let mut recv_handles: Vec<JoinHandle<_>> = vec![];
let mut receivers: Vec<Arc<tokio::sync::Mutex<broadcast::Receiver<_>>>> = vec![];
loop { loop {
let mut reached_max = false;
if let Some(max_clients) = exit_cleanly_after { if let Some(max_clients) = exit_cleanly_after {
if max_clients.get() <= recv_handles.len() { if max_clients.get() <= recv_handles.len() {
// we waited fro all the clients we wanted to see attached. Now wait for them to close their tcp connections. // we waited for all the clients we wanted to see attached. Now wait for them to close their tcp connections.
break; reached_max = true;
} }
} }
//println!("loop");
// Asynchronously wait for an inbound socket. // Asynchronously wait for an inbound socket.
let (socket, _) = listener.accept().await.expect("test"); let (socket, _) = listener.accept().await.expect("Accept failed");
let (mut read, mut write) = tokio::io::split(socket); let (mut read, mut write) = tokio::io::split(socket);
// Protocol: the new client communicate its old ClientId or -1 if new
let mut this_client_id = [0; 4];
read.read_exact(&mut this_client_id)
.await
.expect("Socket closed?");
let this_client_id = ClientId(u32::from_le_bytes(this_client_id));
let (this_client_id, is_old) = if this_client_id == UNDEFINED_CLIENT_ID {
if reached_max {
(UNDEFINED_CLIENT_ID, false) // Dumb id
} else {
// ClientIds for this broker start at 0. // ClientIds for this broker start at 0.
let this_client_id = ClientId(recv_handles.len().try_into().unwrap()); (ClientId(recv_handles.len().try_into().unwrap()), false)
}
} else {
(this_client_id, true)
};
let this_client_id_bytes = this_client_id.0.to_le_bytes(); let this_client_id_bytes = this_client_id.0.to_le_bytes();
// Send the client id for this node; // Protocol: Send the client id for this node;
write.write_all(&this_client_id_bytes).await.unwrap(); write.write_all(&this_client_id_bytes).await.unwrap();
if !is_old && reached_max {
continue;
}
let tx_inner = tx.clone(); let tx_inner = tx.clone();
let mut rx_inner = rx.resubscribe();
// Keep all handles around. let handle = async move {
recv_handles.push(spawn(async move {
// In a loop, read data from the socket and write the data back. // In a loop, read data from the socket and write the data back.
loop { loop {
let mut len_buf = [0; 4]; let mut len_buf = [0; 4];
read.read_exact(&mut len_buf).await.expect("Socket closed?"); if read.read_exact(&mut len_buf).await.is_err() {
// The socket is closed, the client is restarting
log::info!("Socket closed, client restarting");
return;
}
let mut len = u32::from_le_bytes(len_buf); let mut len = u32::from_le_bytes(len_buf);
// we forward the sender id as well, so we add 4 bytes to the message length // we forward the sender id as well, so we add 4 bytes to the message length
@ -158,26 +188,55 @@ where
let mut buf = vec![0; len as usize]; let mut buf = vec![0; len as usize];
read.read_exact(&mut buf) if read
.read_exact(&mut buf)
.await .await
.expect("failed to read data from socket"); // .expect("Failed to read data from socket"); // TODO verify if we have to handle this error
.is_err()
{
// The socket is closed, the client is restarting
log::info!("Socket closed, client restarting");
return;
}
#[cfg(feature = "tcp_debug")] #[cfg(feature = "tcp_debug")]
println!("len: {len:?} - {buf:?}"); println!("len: {len:?} - {buf:?}");
tx_inner.send(buf).await.expect("Could not send"); tx_inner.send(buf).await.expect("Could not send");
} }
})); };
let client_idx = this_client_id.0 as usize;
// Keep all handles around.
if is_old {
recv_handles[client_idx].abort();
recv_handles[client_idx] = spawn(handle);
} else {
recv_handles.push(spawn(handle));
// Get old messages only if new
let rx_inner = Arc::new(tokio::sync::Mutex::new(rx.resubscribe()));
receivers.push(rx_inner.clone());
}
let rx_inner = receivers[client_idx].clone();
// The forwarding end. No need to keep a handle to this (TODO: unless they don't quit/get stuck?) // The forwarding end. No need to keep a handle to this (TODO: unless they don't quit/get stuck?)
spawn(async move { spawn(async move {
// In a loop, read data from the socket and write the data back. // In a loop, read data from the socket and write the data back.
loop { loop {
let buf: Vec<u8> = rx_inner.recv().await.unwrap_or(vec![]); let buf: Vec<u8> = rx_inner
.lock()
.await
.recv()
.await
.expect("Could not receive");
// TODO handle full capacity, Lagged https://docs.rs/tokio/latest/tokio/sync/broadcast/error/enum.RecvError.html
#[cfg(feature = "tcp_debug")] #[cfg(feature = "tcp_debug")]
println!("{buf:?}"); println!("{buf:?}");
if buf.len() <= 4 { if buf.len() <= 4 {
eprintln!("We got no contents (or only the length) in a broadcast"); log::warn!("We got no contents (or only the length) in a broadcast");
continue; continue;
} }
@ -194,17 +253,26 @@ where
let len_buf: [u8; 4] = len.to_le_bytes(); let len_buf: [u8; 4] = len.to_le_bytes();
// Write message length // Write message length
write.write_all(&len_buf).await.expect("Writing failed"); if write.write_all(&len_buf).await.is_err() {
// The socket is closed, the client is restarting
log::info!("Socket closed, client restarting");
return;
}
// Write the rest // Write the rest
write.write_all(&buf).await.expect("Socket closed?"); if write.write_all(&buf).await.is_err() {
// The socket is closed, the client is restarting
log::info!("Socket closed, client restarting");
return;
}
} }
}); });
} }
println!("joining handles..");
/*log::info!("Joining handles..");
// wait for all clients to exit/error out // wait for all clients to exit/error out
for recv_handle in recv_handles { for recv_handle in recv_handles {
drop(recv_handle.await); drop(recv_handle.await);
} }*/
}); });
loop { loop {
@ -386,12 +454,20 @@ impl<S> TcpEventManager<S>
where where
S: UsesInput + HasExecutions + HasClientPerfMonitor, S: UsesInput + HasExecutions + HasClientPerfMonitor,
{ {
/// Create a manager from a raw TCP client /// Create a manager from a raw TCP client specifying the client id
pub fn new<A: ToSocketAddrs>(addr: &A, configuration: EventConfig) -> Result<Self, Error> { pub fn existing<A: ToSocketAddrs>(
addr: &A,
client_id: ClientId,
configuration: EventConfig,
) -> Result<Self, Error> {
let mut tcp = TcpStream::connect(addr)?; let mut tcp = TcpStream::connect(addr)?;
let mut our_client_id_buf = [0_u8; 4]; let mut our_client_id_buf = client_id.0.to_le_bytes();
tcp.read_exact(&mut our_client_id_buf).unwrap(); tcp.write_all(&our_client_id_buf)
.expect("Cannot write to the broker");
tcp.read_exact(&mut our_client_id_buf)
.expect("Cannot read from the broker");
let client_id = ClientId(u32::from_le_bytes(our_client_id_buf)); let client_id = ClientId(u32::from_le_bytes(our_client_id_buf));
println!("Our client id: {client_id:?}"); println!("Our client id: {client_id:?}");
@ -407,15 +483,49 @@ where
}) })
} }
/// Create a manager from a raw TCP client that has no previously assigned client id
///
/// Hands `UNDEFINED_CLIENT_ID` to [`Self::existing`] together with `addr` and
/// `configuration`, and returns whatever that call returns.
pub fn new<A: ToSocketAddrs>(addr: &A, configuration: EventConfig) -> Result<Self, Error> {
    // No id from an earlier run is known here, so pass the placeholder id along.
    let no_previous_id = UNDEFINED_CLIENT_ID;
    Self::existing(addr, no_previous_id, configuration)
}
/// Create a TCP event manager on a localhost port, specifying the client id
///
/// Forwards `client_id` and `configuration` to [`Self::existing`], using
/// `("127.0.0.1", port)` as the address.
///
/// NOTE(review): earlier documentation stated that this acts as a broker when the
/// port is not yet bound; this function itself only forwards to `existing` — confirm
/// where (or whether) the broker role is decided.
pub fn existing_on_port(
    port: u16,
    client_id: ClientId,
    configuration: EventConfig,
) -> Result<Self, Error> {
    let localhost = ("127.0.0.1", port);
    Self::existing(&localhost, client_id, configuration)
}
/// Create an TCP event manager on a port /// Create an TCP event manager on a port
/// ///
/// If the port is not yet bound, it will act as a broker; otherwise, it /// If the port is not yet bound, it will act as a broker; otherwise, it
/// will act as a client. /// will act as a client.
#[cfg(feature = "std")]
pub fn on_port(port: u16, configuration: EventConfig) -> Result<Self, Error> { pub fn on_port(port: u16, configuration: EventConfig) -> Result<Self, Error> {
Self::new(&("127.0.0.1", port), configuration) Self::new(&("127.0.0.1", port), configuration)
} }
/// Create a TCP event manager for `addr`, taking the client id from an env variable
///
/// The variable named `env_name` is read and parsed as a `u32`. A failure of either
/// step is propagated to the caller through `?`. The parsed id is then handed to
/// [`Self::existing`] together with `addr` and `configuration`.
pub fn existing_from_env<A: ToSocketAddrs>(
    addr: &A,
    env_name: &str,
    configuration: EventConfig,
) -> Result<Self, Error> {
    // Read the textual id first, then convert it to the numeric id.
    let raw_id = env::var(env_name)?;
    let numeric_id = raw_id.parse::<u32>()?;
    Self::existing(addr, ClientId(numeric_id), configuration)
}
/// Store the client id of this client [`EventManager`] in an env variable
///
/// The inner numeric id is rendered with its `Display` form and written to the
/// variable named `env_name`.
pub fn to_env(&self, env_name: &str) {
    let id_text = self.client_id.0.to_string();
    env::set_var(env_name, id_text);
}
// Handle arriving events in the client // Handle arriving events in the client
#[allow(clippy::unused_self)] #[allow(clippy::unused_self)]
fn handle_in_client<E, Z>( fn handle_in_client<E, Z>(
@ -731,8 +841,11 @@ where
fn on_restart(&mut self, state: &mut S) -> Result<(), Error> { fn on_restart(&mut self, state: &mut S) -> Result<(), Error> {
// First, reset the page to 0 so the next iteration can read from the beginning of this page // First, reset the page to 0 so the next iteration can read from the beginning of this page
self.staterestorer.reset(); self.staterestorer.reset();
self.staterestorer self.staterestorer.save(&if self.save_state {
.save(&if self.save_state { Some(state) } else { None })?; Some((state, self.tcp_mgr.client_id))
} else {
None
})?;
self.await_restart_safe(); self.await_restart_safe();
Ok(()) Ok(())
} }
@ -938,7 +1051,7 @@ where
}; };
// We get here if we are on Unix, or we are a broker on Windows (or without forks). // We get here if we are on Unix, or we are a broker on Windows (or without forks).
let (_mgr, core_id) = match self.kind { let (mgr, core_id) = match self.kind {
ManagerKind::Any => { ManagerKind::Any => {
let connection = create_nonblocking_listener(("127.0.0.1", self.broker_port)); let connection = create_nonblocking_listener(("127.0.0.1", self.broker_port));
match connection { match connection {
@ -994,7 +1107,7 @@ where
} }
// We are the fuzzer respawner in a tcp client // We are the fuzzer respawner in a tcp client
//mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL); mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL);
// First, create a channel from the current fuzzer to the next to store state between restarts. // First, create a channel from the current fuzzer to the next to store state between restarts.
#[cfg(unix)] #[cfg(unix)]
@ -1030,6 +1143,7 @@ where
// Client->parent loop // Client->parent loop
loop { loop {
log::info!("Spawning next client (id {ctr})"); log::info!("Spawning next client (id {ctr})");
println!("Spawning next client (id {ctr}) {core_id:?}");
// On Unix, we fork (when fork feature is enabled) // On Unix, we fork (when fork feature is enabled)
#[cfg(all(unix, feature = "fork"))] #[cfg(all(unix, feature = "fork"))]
@ -1091,11 +1205,15 @@ where
} }
// If we're restarting, deserialize the old state. // If we're restarting, deserialize the old state.
let (state, mut mgr) = if let Some(state_opt) = staterestorer.restore()? { let (state, mut mgr) = if let Some((state_opt, this_id)) = staterestorer.restore()? {
( (
state_opt, state_opt,
TcpRestartingEventManager::with_save_state( TcpRestartingEventManager::with_save_state(
TcpEventManager::on_port(self.broker_port, self.configuration)?, TcpEventManager::existing_on_port(
self.broker_port,
this_id,
self.configuration,
)?,
staterestorer, staterestorer,
self.serialize_state, self.serialize_state,
), ),
@ -1103,7 +1221,11 @@ where
} else { } else {
log::info!("First run. Let's set it all up"); log::info!("First run. Let's set it all up");
// Mgr to send and receive msgs from/to all other fuzzer instances // Mgr to send and receive msgs from/to all other fuzzer instances
let mgr = TcpEventManager::<S>::on_port(self.broker_port, self.configuration)?; let mgr = TcpEventManager::<S>::existing_from_env(
&("127.0.0.1", self.broker_port),
_ENV_FUZZER_BROKER_CLIENT_INITIAL,
self.configuration,
)?;
( (
None, None,