diff --git a/binaries/coordinator/src/control.rs b/binaries/coordinator/src/control.rs index 60e59fd3..51a8219c 100644 --- a/binaries/coordinator/src/control.rs +++ b/binaries/coordinator/src/control.rs @@ -6,6 +6,7 @@ use dora_core::topics::{ControlRequest, ControlRequestReply}; use eyre::{eyre, Context}; use futures::{ future::{self, Either}, + stream::FuturesUnordered, FutureExt, Stream, StreamExt, }; use futures_concurrency::future::Race; @@ -13,20 +14,30 @@ use std::{io::ErrorKind, net::SocketAddr}; use tokio::{ net::{TcpListener, TcpStream}, sync::{mpsc, oneshot}, + task::JoinHandle, }; use tokio_stream::wrappers::ReceiverStream; pub(crate) async fn control_events( control_listen_addr: SocketAddr, + tasks: &FuturesUnordered>, ) -> eyre::Result> { let (tx, rx) = mpsc::channel(10); - tokio::spawn(listen(control_listen_addr, tx)); + let (finish_tx, mut finish_rx) = mpsc::channel(1); + tasks.push(tokio::spawn(listen(control_listen_addr, tx, finish_tx))); + tasks.push(tokio::spawn(async move { + while let Some(()) = finish_rx.recv().await {} + })); Ok(ReceiverStream::new(rx).map(Event::Control)) } -async fn listen(control_listen_addr: SocketAddr, tx: mpsc::Sender) { +async fn listen( + control_listen_addr: SocketAddr, + tx: mpsc::Sender, + _finish_tx: mpsc::Sender<()>, +) { let result = TcpListener::bind(control_listen_addr) .await .wrap_err("failed to listen for control messages"); @@ -51,7 +62,7 @@ async fn listen(control_listen_addr: SocketAddr, tx: mpsc::Sender) match connection.wrap_err("failed to connect") { Ok((connection, _)) => { let tx = tx.clone(); - tokio::spawn(handle_requests(connection, tx)); + tokio::spawn(handle_requests(connection, tx, _finish_tx.clone())); } Err(err) => { if tx.blocking_send(err.into()).is_err() { @@ -62,7 +73,11 @@ async fn listen(control_listen_addr: SocketAddr, tx: mpsc::Sender) } } -async fn handle_requests(mut connection: TcpStream, tx: mpsc::Sender) { +async fn handle_requests( + mut connection: TcpStream, + tx: mpsc::Sender, + _finish_tx: mpsc::Sender<()>, +) { loop { let next_request = tcp_receive(&mut connection).map(Either::Left); let coordinator_stopped = tx.closed().map(Either::Right); diff --git a/binaries/coordinator/src/lib.rs b/binaries/coordinator/src/lib.rs index b8a6467e..701385af 100644 --- a/binaries/coordinator/src/lib.rs +++ b/binaries/coordinator/src/lib.rs @@ -78,7 +78,7 @@ async fn start(runtime_path: &Path, tasks: &FuturesUnordered>) -> let mut daemon_events_tx = Some(daemon_events_tx); let daemon_events = ReceiverStream::new(daemon_events); - let control_events = control::control_events(control_socket_addr()) + let control_events = control::control_events(control_socket_addr(), tasks) .await .wrap_err("failed to create control events")?;