diff --git a/Cargo.lock b/Cargo.lock index 7dd0050f..5ec3dcd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -960,6 +960,7 @@ dependencies = [ "bincode", "clap 3.2.20", "communication-layer-request-reply", + "ctrlc", "dora-core", "dora-download", "dora-node-api", diff --git a/binaries/coordinator/Cargo.toml b/binaries/coordinator/Cargo.toml index c1ee3e18..c75bc79b 100644 --- a/binaries/coordinator/Cargo.toml +++ b/binaries/coordinator/Cargo.toml @@ -30,3 +30,4 @@ dora-download = { path = "../../libraries/extensions/download" } which = "4.3.0" communication-layer-request-reply = { path = "../../libraries/communication-layer/request-reply" } thiserror = "1.0.37" +ctrlc = "3.2.5" diff --git a/binaries/coordinator/src/lib.rs b/binaries/coordinator/src/lib.rs index db87a046..c45a977b 100644 --- a/binaries/coordinator/src/lib.rs +++ b/binaries/coordinator/src/lib.rs @@ -21,7 +21,7 @@ use std::{ path::{Path, PathBuf}, time::Duration, }; -use tokio::net::TcpStream; +use tokio::{net::TcpStream, sync::mpsc}; use tokio_stream::wrappers::{ReceiverStream, TcpListenerStream}; use uuid::Uuid; @@ -55,6 +55,9 @@ pub async fn run(args: Args) -> eyre::Result<()> { } async fn start(runtime_path: &Path) -> eyre::Result<()> { /* Register the Ctrl-C handler before anything else is set up. The sender is kept in an Option so that handle_destroy can drop it later by assigning None. */ + let (ctrlc_tx, ctrlc_rx) = set_up_ctrlc_handler()?; + let mut ctrlc_tx_handle = Some(ctrlc_tx); + let listener = listener::create_listener(DORA_COORDINATOR_PORT_DEFAULT).await?; let (new_daemon_connections, new_daemon_connections_abort) = futures::stream::abortable(TcpListenerStream::new(listener).map(|c| { @@ -76,12 +79,14 @@ async fn start(runtime_path: &Path) -> eyre::Result<()> { let daemon_watchdog_interval = tokio_stream::wrappers::IntervalStream::new(tokio::time::interval(Duration::from_secs(1))) .map(|_| Event::DaemonWatchdogInterval); /* Wrap the Ctrl-C receiver as a stream so it can be merged with the other event sources below. */ + let ctrlc_events = ReceiverStream::new(ctrlc_rx); let mut events = ( new_daemon_connections, daemon_events, control_events, daemon_watchdog_interval, + ctrlc_events, ) .merge(); @@ -249,20 +254,15 @@ async fn 
start(runtime_path: &Path) -> eyre::Result<()> { ControlRequest::Destroy => { tracing::info!("Received destroy command"); - control_events_abort.abort(); - - // stop all running dataflows - for &uuid in running_dataflows.keys() { - stop_dataflow(&running_dataflows, uuid, &mut daemon_connections) - .await?; - } - - // destroy all connected daemons - destroy_daemons(&mut daemon_connections).await?; - - // prevent the creation of new daemon connections - new_daemon_connections_abort.abort(); - daemon_events_tx = None; /* The inline shutdown steps removed above now live in handle_destroy, which the Ctrl-C branch calls with the same arguments. */ + handle_destroy( + &control_events_abort, + &running_dataflows, + &mut daemon_connections, + &new_daemon_connections_abort, + &mut daemon_events_tx, + &mut ctrlc_tx_handle, + ) + .await?; b"ok".as_slice().into() } @@ -313,6 +313,18 @@ async fn start(runtime_path: &Path) -> eyre::Result<()> { } } } /* Ctrl-C runs the same shutdown sequence as the Destroy control request, but produces no reply. */ + Event::CtrlC => { + tracing::info!("Destroying coordinator after receiving Ctrl-C signal"); + handle_destroy( + &control_events_abort, + &running_dataflows, + &mut daemon_connections, + &new_daemon_connections_abort, + &mut daemon_events_tx, + &mut ctrlc_tx_handle, + ) + .await?; + } } } @@ -321,6 +333,49 @@ async fn start(runtime_path: &Path) -> eyre::Result<()> { Ok(()) } /* set_up_ctrlc_handler: creates an mpsc channel of capacity 1 and registers a handler through ctrlc::set_handler. The closure owns only a weak sender (obtained with downgrade) plus a ctrlc_sent flag. On the first signal it logs, upgrades the weak sender and, if that succeeds, pushes Event::CtrlC with blocking_send, logging an error when the send fails; it then sets the flag. On any later signal it logs a warning and calls std::process::abort. Returns the strong sender and the receiver, or an error wrapped with the message shown when set_handler fails. NOTE(review): blocking_send presumably relies on the handler running outside the async runtime - confirm against the ctrlc crate docs. NOTE(review): the generic arguments on mpsc::Sender and mpsc::Receiver look stripped in this diff text; presumably both carry Event - confirm against the real file. */ +fn set_up_ctrlc_handler() -> Result<(mpsc::Sender, mpsc::Receiver), eyre::ErrReport> { + let (ctrlc_tx, ctrlc_rx) = mpsc::channel(1); + + let ctrlc_tx_weak = ctrlc_tx.downgrade(); + let mut ctrlc_sent = false; + ctrlc::set_handler(move || { + if ctrlc_sent { + tracing::warn!("received second ctrlc signal -> aborting immediately"); + std::process::abort(); + } else { + tracing::info!("received ctrlc signal"); + if let Some(ctrlc_tx) = ctrlc_tx_weak.upgrade() { + if ctrlc_tx.blocking_send(Event::CtrlC).is_err() { + tracing::error!("failed to report ctrl-c event to dora-coordinator"); + } + } + ctrlc_sent = true; + } + }) + .wrap_err("failed to set ctrl-c handler")?; + + Ok((ctrlc_tx, ctrlc_rx)) +} + /* handle_destroy: shutdown sequence shared by the Destroy control request and the Ctrl-C event. In order it aborts the control event stream, calls stop_dataflow for every key of running_dataflows, calls destroy_daemons, aborts the stream of new daemon connections, and sets both the daemon event sender and the Ctrl-C sender to None. An error from stop_dataflow or destroy_daemons is returned immediately through the ? operator, so the later steps are skipped in that case. NOTE(review): the type arguments of the HashMap and Option parameters look stripped in this diff text - restore them from the real file before applying. */ +async fn handle_destroy( + control_events_abort: &futures::stream::AbortHandle, + 
running_dataflows: &HashMap, + daemon_connections: &mut HashMap, + new_daemon_connections_abort: &futures::stream::AbortHandle, + daemon_events_tx: &mut Option>, + ctrlc_tx: &mut Option>, +) -> Result<(), eyre::ErrReport> { + control_events_abort.abort(); + for &uuid in running_dataflows.keys() { + stop_dataflow(running_dataflows, uuid, daemon_connections).await?; + } + destroy_daemons(daemon_connections).await?; + new_daemon_connections_abort.abort(); + *daemon_events_tx = None; /* The signal handler holds only a weak sender, so dropping this strong sender makes its upgrade call return None afterwards; presumably this also lets the merged Ctrl-C stream finish - confirm. */ + *ctrlc_tx = None; + Ok(()) +} + async fn send_watchdog_message(connection: &mut TcpStream) -> eyre::Result<()> { let message = serde_json::to_vec(&DaemonCoordinatorEvent::Watchdog).unwrap(); @@ -447,7 +502,9 @@ pub enum Event { Control(ControlEvent), Daemon(DaemonEvent), DaemonWatchdogInterval, /* Sent by the Ctrl-C handler on the first signal; the event loop answers it by calling handle_destroy. */ + CtrlC, } + impl Event { /// Whether this event should be logged. #[allow(clippy::match_like_matches_macro)]