From 3775a3c08b3f42e4a51e597caf65a687c749564f Mon Sep 17 00:00:00 2001 From: haixuanTao Date: Wed, 9 Oct 2024 12:01:50 +0200 Subject: [PATCH] Make downloading file use the `Content-Disposition` header or the filename within the URL to get the path, to avoid confusion about the name of the file. --- binaries/daemon/src/spawn.rs | 10 ++-- binaries/runtime/src/operator/python.rs | 7 +-- binaries/runtime/src/operator/shared_lib.rs | 13 ++---- libraries/extensions/download/src/lib.rs | 52 ++++++++++++++------- 4 files changed, 43 insertions(+), 39 deletions(-) diff --git a/binaries/daemon/src/spawn.rs b/binaries/daemon/src/spawn.rs index 78302f0d..87eca5a3 100644 --- a/binaries/daemon/src/spawn.rs +++ b/binaries/daemon/src/spawn.rs @@ -27,7 +27,6 @@ use dora_node_api::{ }; use eyre::{ContextCompat, WrapErr}; use std::{ - env::consts::EXE_EXTENSION, path::{Path, PathBuf}, process::Stdio, sync::Arc, @@ -101,13 +100,10 @@ pub async fn spawn_node( source => { let resolved_path = if source_is_url(source) { // try to download the shared library - let target_path = Path::new("build") - .join(node_id.to_string()) - .with_extension(EXE_EXTENSION); - download_file(source, &target_path) + let target_dir = Path::new("build"); + download_file(source, &target_dir) .await - .wrap_err("failed to download custom node")?; - target_path.clone() + .wrap_err("failed to download custom node")? 
} else { resolve_path(source, working_dir).wrap_err_with(|| { format!("failed to resolve node source `{}`", source) diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index c511f463..4885dc37 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -42,16 +42,13 @@ pub fn run( dataflow_descriptor: &Descriptor, ) -> eyre::Result<()> { let path = if source_is_url(&python_source.source) { - let target_path = Path::new("build") - .join(node_id.to_string()) - .join(format!("{}.py", operator_id)); + let target_path = Path::new("build"); // try to download the shared library let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; rt.block_on(download_file(&python_source.source, &target_path)) - .wrap_err("failed to download Python operator")?; - target_path + .wrap_err("failed to download Python operator")? } else { Path::new(&python_source.source).to_owned() }; diff --git a/binaries/runtime/src/operator/shared_lib.rs b/binaries/runtime/src/operator/shared_lib.rs index 70fccff4..e7ec3068 100644 --- a/binaries/runtime/src/operator/shared_lib.rs +++ b/binaries/runtime/src/operator/shared_lib.rs @@ -27,26 +27,21 @@ use tokio::sync::{mpsc::Sender, oneshot}; use tracing::{field, span}; pub fn run( - node_id: &NodeId, - operator_id: &OperatorId, + _node_id: &NodeId, + _operator_id: &OperatorId, source: &str, events_tx: Sender, incoming_events: flume::Receiver, init_done: oneshot::Sender>, ) -> eyre::Result<()> { let path = if source_is_url(source) { - let target_path = adjust_shared_library_path( - &Path::new("build") - .join(node_id.to_string()) - .join(operator_id.to_string()), - )?; + let target_path = &Path::new("build"); // try to download the shared library let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build()?; rt.block_on(download_file(source, &target_path)) - .wrap_err("failed to download shared library operator")?; - target_path + 
.wrap_err("failed to download shared library operator")? } else { adjust_shared_library_path(Path::new(source))? }; diff --git a/libraries/extensions/download/src/lib.rs b/libraries/extensions/download/src/lib.rs index b0843c83..1e49fee3 100644 --- a/libraries/extensions/download/src/lib.rs +++ b/libraries/extensions/download/src/lib.rs @@ -1,35 +1,51 @@ -use eyre::Context; +use eyre::{Context, ContextCompat}; #[cfg(unix)] use std::os::unix::prelude::PermissionsExt; -use std::path::Path; +use std::path::{Path, PathBuf}; use tokio::io::AsyncWriteExt; -use tracing::info; -pub async fn download_file(url: T, target_path: &Path) -> Result<(), eyre::ErrReport> -where - T: reqwest::IntoUrl + std::fmt::Display + Copy, -{ - if target_path.exists() { - info!("Using cache: {:?}", target_path.to_str()); - return Ok(()); +fn get_filename(response: &reqwest::Response) -> Option { + if let Some(content_disposition) = response.headers().get("content-disposition") { + if let Ok(filename) = content_disposition.to_str() { + if let Some(name) = filename.split("filename=").nth(1) { + return Some(name.trim_matches('"').to_string()); + } + } } - if let Some(parent) = target_path.parent() { - tokio::fs::create_dir_all(parent) - .await - .wrap_err("failed to create parent folder")?; + // If Content-Disposition header is not available, extract from URL + let path = Path::new(response.url().as_str()); + if let Some(name) = path.file_name() { + if let Some(filename) = name.to_str() { + return Some(filename.to_string()); + } } + None +} + +pub async fn download_file(url: T, target_dir: &Path) -> Result +where + T: reqwest::IntoUrl + std::fmt::Display + Copy, +{ + tokio::fs::create_dir_all(&target_dir) + .await + .wrap_err("failed to create parent folder")?; + let response = reqwest::get(url) .await - .wrap_err_with(|| format!("failed to request operator from `{url}`"))? 
+ .wrap_err_with(|| format!("failed to request operator from `{url}`"))?; + + let filename = get_filename(&response).context("Could not find a filename")?; + let bytes = response .bytes() .await .wrap_err("failed to read operator from `{uri}`")?; - let mut file = tokio::fs::File::create(target_path) + let path = target_dir.join(filename); + let mut file = tokio::fs::File::create(&path) .await .wrap_err("failed to create target file")?; - file.write_all(&response) + file.write_all(&bytes) .await .wrap_err("failed to write downloaded operator to file")?; file.sync_all().await.wrap_err("failed to `sync_all`")?; @@ -39,5 +55,5 @@ where .await .wrap_err("failed to make downloaded file executable")?; - Ok(()) + Ok(path.to_path_buf()) }