diff --git a/libafl/src/events/tcp.rs b/libafl/src/events/tcp.rs index b484bfa279..ff7e9a25d9 100644 --- a/libafl/src/events/tcp.rs +++ b/libafl/src/events/tcp.rs @@ -11,8 +11,10 @@ use core::{ sync::atomic::{compiler_fence, Ordering}, }; use std::{ + env, io::{ErrorKind, Read, Write}, net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}, + sync::Arc, }; #[cfg(feature = "std")] @@ -30,7 +32,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::{broadcast, mpsc}, - task::spawn, + task::{spawn, JoinHandle}, }; #[cfg(feature = "std")] use typed_builder::TypedBuilder; @@ -75,6 +77,8 @@ where phantom: PhantomData, } +const UNDEFINED_CLIENT_ID: ClientId = ClientId(0xffffffff); + impl TcpEventBroker where I: Input, @@ -105,9 +109,10 @@ where /// Run in the broker until all clients exit #[tokio::main(flavor = "current_thread")] + #[allow(clippy::too_many_lines)] pub async fn broker_loop(&mut self) -> Result<(), Error> { - let (tx_bc, rx) = broadcast::channel(128); - let (tx, mut rx_mpsc) = mpsc::channel(128); + let (tx_bc, rx) = broadcast::channel(1024); + let (tx, mut rx_mpsc) = mpsc::channel(1024); let exit_cleanly_after = self.exit_cleanly_after; @@ -118,36 +123,61 @@ where let listener = tokio::net::TcpListener::from_std(listener)?; let tokio_broker = spawn(async move { - let mut recv_handles = vec![]; + let mut recv_handles: Vec> = vec![]; + let mut receivers: Vec>>> = vec![]; loop { + let mut reached_max = false; if let Some(max_clients) = exit_cleanly_after { if max_clients.get() <= recv_handles.len() { - // we waited fro all the clients we wanted to see attached. Now wait for them to close their tcp connections. - break; + // we waited for all the clients we wanted to see attached. Now wait for them to close their tcp connections. + reached_max = true; } } - //println!("loop"); // Asynchronously wait for an inbound socket. - let (socket, _) = listener.accept().await.expect("test"); + let (socket, _) = listener.accept().await.expect("Accept failed"); let (mut read, mut write) = tokio::io::split(socket); - // ClientIds for this broker start at 0. - let this_client_id = ClientId(recv_handles.len().try_into().unwrap()); + + // Protocol: the new client communicate its old ClientId or -1 if new + let mut this_client_id = [0; 4]; + read.read_exact(&mut this_client_id) + .await + .expect("Socket closed?"); + let this_client_id = ClientId(u32::from_le_bytes(this_client_id)); + + let (this_client_id, is_old) = if this_client_id == UNDEFINED_CLIENT_ID { + if reached_max { + (UNDEFINED_CLIENT_ID, false) // Dumb id + } else { + // ClientIds for this broker start at 0. + (ClientId(recv_handles.len().try_into().unwrap()), false) + } + } else { + (this_client_id, true) + }; + let this_client_id_bytes = this_client_id.0.to_le_bytes(); - // Send the client id for this node; + // Protocol: Send the client id for this node; write.write_all(&this_client_id_bytes).await.unwrap(); + if !is_old && reached_max { + continue; + } + let tx_inner = tx.clone(); - let mut rx_inner = rx.resubscribe(); - // Keep all handles around. - recv_handles.push(spawn(async move { + + let handle = async move { // In a loop, read data from the socket and write the data back. loop { let mut len_buf = [0; 4]; - read.read_exact(&mut len_buf).await.expect("Socket closed?"); + if read.read_exact(&mut len_buf).await.is_err() { + // The socket is closed, the client is restarting + log::info!("Socket closed, client restarting"); + return; + } let mut len = u32::from_le_bytes(len_buf); // we forward the sender id as well, so we add 4 bytes to the message length @@ -158,26 +188,55 @@ where let mut buf = vec![0; len as usize]; - read.read_exact(&mut buf) + if read + .read_exact(&mut buf) .await - .expect("failed to read data from socket"); + // .expect("Failed to read data from socket"); // TODO verify if we have to handle this error + .is_err() + { + // The socket is closed, the client is restarting + log::info!("Socket closed, client restarting"); + return; + } #[cfg(feature = "tcp_debug")] println!("len: {len:?} - {buf:?}"); tx_inner.send(buf).await.expect("Could not send"); } - })); + }; + + let client_idx = this_client_id.0 as usize; + + // Keep all handles around. + if is_old { + recv_handles[client_idx].abort(); + recv_handles[client_idx] = spawn(handle); + } else { + recv_handles.push(spawn(handle)); + // Get old messages only if new + let rx_inner = Arc::new(tokio::sync::Mutex::new(rx.resubscribe())); + receivers.push(rx_inner.clone()); + } + + let rx_inner = receivers[client_idx].clone(); + // The forwarding end. No need to keep a handle to this (TODO: unless they don't quit/get stuck?) spawn(async move { // In a loop, read data from the socket and write the data back. loop { - let buf: Vec = rx_inner.recv().await.unwrap_or(vec![]); + let buf: Vec = rx_inner + .lock() + .await + .recv() + .await + .expect("Could not receive"); + // TODO handle full capacity, Lagged https://docs.rs/tokio/latest/tokio/sync/broadcast/error/enum.RecvError.html #[cfg(feature = "tcp_debug")] println!("{buf:?}"); if buf.len() <= 4 { - eprintln!("We got no contents (or only the length) in a broadcast"); + log::warn!("We got no contents (or only the length) in a broadcast"); continue; } @@ -194,17 +253,26 @@ where let len_buf: [u8; 4] = len.to_le_bytes(); // Write message length - write.write_all(&len_buf).await.expect("Writing failed"); + if write.write_all(&len_buf).await.is_err() { + // The socket is closed, the client is restarting + log::info!("Socket closed, client restarting"); + return; + } // Write the rest - write.write_all(&buf).await.expect("Socket closed?"); + if write.write_all(&buf).await.is_err() { + // The socket is closed, the client is restarting + log::info!("Socket closed, client restarting"); + return; + } } }); } - println!("joining handles.."); + + /*log::info!("Joining handles.."); // wait for all clients to exit/error out for recv_handle in recv_handles { drop(recv_handle.await); - } + }*/ }); loop { @@ -386,12 +454,20 @@ impl TcpEventManager where S: UsesInput + HasExecutions + HasClientPerfMonitor, { - /// Create a manager from a raw TCP client - pub fn new(addr: &A, configuration: EventConfig) -> Result { + /// Create a manager from a raw TCP client specifying the client id + pub fn existing( + addr: &A, + client_id: ClientId, + configuration: EventConfig, + ) -> Result { let mut tcp = TcpStream::connect(addr)?; - let mut our_client_id_buf = [0_u8; 4]; - tcp.read_exact(&mut our_client_id_buf).unwrap(); + let mut our_client_id_buf = client_id.0.to_le_bytes(); + tcp.write_all(&our_client_id_buf) + .expect("Cannot write to the broker"); + + tcp.read_exact(&mut our_client_id_buf) + .expect("Cannot read from the broker"); let client_id = ClientId(u32::from_le_bytes(our_client_id_buf)); println!("Our client id: {client_id:?}"); @@ -407,15 +483,49 @@ where }) } + /// Create a manager from a raw TCP client + pub fn new(addr: &A, configuration: EventConfig) -> Result { + Self::existing(addr, UNDEFINED_CLIENT_ID, configuration) + } + + /// Create an TCP event manager on a port specifying the client id + /// + /// If the port is not yet bound, it will act as a broker; otherwise, it + /// will act as a client. + pub fn existing_on_port( + port: u16, + client_id: ClientId, + configuration: EventConfig, + ) -> Result { + Self::existing(&("127.0.0.1", port), client_id, configuration) + } + /// Create an TCP event manager on a port /// /// If the port is not yet bound, it will act as a broker; otherwise, it /// will act as a client. - #[cfg(feature = "std")] pub fn on_port(port: u16, configuration: EventConfig) -> Result { Self::new(&("127.0.0.1", port), configuration) } + /// Create an TCP event manager on a port specifying the client id from env + /// + /// If the port is not yet bound, it will act as a broker; otherwise, it + /// will act as a client. + pub fn existing_from_env( + addr: &A, + env_name: &str, + configuration: EventConfig, + ) -> Result { + let this_id = ClientId(str::parse::(&env::var(env_name)?)?); + Self::existing(addr, this_id, configuration) + } + + /// Write the client id for a client [`EventManager`] to env vars + pub fn to_env(&self, env_name: &str) { + env::set_var(env_name, format!("{}", self.client_id.0)); + } + // Handle arriving events in the client #[allow(clippy::unused_self)] fn handle_in_client( @@ -731,8 +841,11 @@ where fn on_restart(&mut self, state: &mut S) -> Result<(), Error> { // First, reset the page to 0 so the next iteration can read read from the beginning of this page self.staterestorer.reset(); - self.staterestorer - .save(&if self.save_state { Some(state) } else { None })?; + self.staterestorer.save(&if self.save_state { + Some((state, self.tcp_mgr.client_id)) + } else { + None + })?; self.await_restart_safe(); Ok(()) } @@ -938,7 +1051,7 @@ where }; // We get here if we are on Unix, or we are a broker on Windows (or without forks). - let (_mgr, core_id) = match self.kind { + let (mgr, core_id) = match self.kind { ManagerKind::Any => { let connection = create_nonblocking_listener(("127.0.0.1", self.broker_port)); match connection { @@ -994,7 +1107,7 @@ where } // We are the fuzzer respawner in a tcp client - //mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL); + mgr.to_env(_ENV_FUZZER_BROKER_CLIENT_INITIAL); // First, create a channel from the current fuzzer to the next to store state between restarts. #[cfg(unix)] @@ -1030,6 +1143,7 @@ where // Client->parent loop loop { log::info!("Spawning next client (id {ctr})"); + println!("Spawning next client (id {ctr}) {core_id:?}"); // On Unix, we fork (when fork feature is enabled) #[cfg(all(unix, feature = "fork"))] @@ -1091,11 +1205,15 @@ where } // If we're restarting, deserialize the old state. - let (state, mut mgr) = if let Some(state_opt) = staterestorer.restore()? { + let (state, mut mgr) = if let Some((state_opt, this_id)) = staterestorer.restore()? { ( state_opt, TcpRestartingEventManager::with_save_state( - TcpEventManager::on_port(self.broker_port, self.configuration)?, + TcpEventManager::existing_on_port( + self.broker_port, + this_id, + self.configuration, + )?, staterestorer, self.serialize_state, ), @@ -1103,7 +1221,11 @@ where } else { log::info!("First run. Let's set it all up"); // Mgr to send and receive msgs from/to all other fuzzer instances - let mgr = TcpEventManager::::on_port(self.broker_port, self.configuration)?; + let mgr = TcpEventManager::::existing_from_env( + &("127.0.0.1", self.broker_port), + _ENV_FUZZER_BROKER_CLIENT_INITIAL, + self.configuration, + )?; ( None,