Browse Source

Make file downloads use the `Content-Disposition` header, or the filename within the URL, to determine the target path, avoiding confusion about the name of the downloaded file.

tags/v0.3.7rc0
haixuanTao 1 year ago
parent
commit
3775a3c08b
4 changed files with 43 additions and 39 deletions
  1. +3
    -7
      binaries/daemon/src/spawn.rs
  2. +2
    -5
      binaries/runtime/src/operator/python.rs
  3. +4
    -9
      binaries/runtime/src/operator/shared_lib.rs
  4. +34
    -18
      libraries/extensions/download/src/lib.rs

+ 3
- 7
binaries/daemon/src/spawn.rs View File

@@ -27,7 +27,6 @@ use dora_node_api::{
};
use eyre::{ContextCompat, WrapErr};
use std::{
env::consts::EXE_EXTENSION,
path::{Path, PathBuf},
process::Stdio,
sync::Arc,
@@ -101,13 +100,10 @@ pub async fn spawn_node(
source => {
let resolved_path = if source_is_url(source) {
// try to download the shared library
let target_path = Path::new("build")
.join(node_id.to_string())
.with_extension(EXE_EXTENSION);
download_file(source, &target_path)
let target_dir = Path::new("build");
download_file(source, &target_dir)
.await
.wrap_err("failed to download custom node")?;
target_path.clone()
.wrap_err("failed to download custom node")?
} else {
resolve_path(source, working_dir).wrap_err_with(|| {
format!("failed to resolve node source `{}`", source)


+ 2
- 5
binaries/runtime/src/operator/python.rs View File

@@ -42,16 +42,13 @@ pub fn run(
dataflow_descriptor: &Descriptor,
) -> eyre::Result<()> {
let path = if source_is_url(&python_source.source) {
let target_path = Path::new("build")
.join(node_id.to_string())
.join(format!("{}.py", operator_id));
let target_path = Path::new("build");
// try to download the shared library
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(download_file(&python_source.source, &target_path))
.wrap_err("failed to download Python operator")?;
target_path
.wrap_err("failed to download Python operator")?
} else {
Path::new(&python_source.source).to_owned()
};


+ 4
- 9
binaries/runtime/src/operator/shared_lib.rs View File

@@ -27,26 +27,21 @@ use tokio::sync::{mpsc::Sender, oneshot};
use tracing::{field, span};

pub fn run(
node_id: &NodeId,
operator_id: &OperatorId,
_node_id: &NodeId,
_operator_id: &OperatorId,
source: &str,
events_tx: Sender<OperatorEvent>,
incoming_events: flume::Receiver<Event>,
init_done: oneshot::Sender<Result<()>>,
) -> eyre::Result<()> {
let path = if source_is_url(source) {
let target_path = adjust_shared_library_path(
&Path::new("build")
.join(node_id.to_string())
.join(operator_id.to_string()),
)?;
let target_path = &Path::new("build");
// try to download the shared library
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(download_file(source, &target_path))
.wrap_err("failed to download shared library operator")?;
target_path
.wrap_err("failed to download shared library operator")?
} else {
adjust_shared_library_path(Path::new(source))?
};


+ 34
- 18
libraries/extensions/download/src/lib.rs View File

@@ -1,35 +1,51 @@
use eyre::Context;
use eyre::{Context, ContextCompat};
#[cfg(unix)]
use std::os::unix::prelude::PermissionsExt;
use std::path::Path;
use std::path::{Path, PathBuf};
use tokio::io::AsyncWriteExt;
use tracing::info;

pub async fn download_file<T>(url: T, target_path: &Path) -> Result<(), eyre::ErrReport>
where
T: reqwest::IntoUrl + std::fmt::Display + Copy,
{
if target_path.exists() {
info!("Using cache: {:?}", target_path.to_str());
return Ok(());
fn get_filename(response: &reqwest::Response) -> Option<String> {
if let Some(content_disposition) = response.headers().get("content-disposition") {
if let Ok(filename) = content_disposition.to_str() {
if let Some(name) = filename.split("filename=").nth(1) {
return Some(name.trim_matches('"').to_string());
}
}
}

if let Some(parent) = target_path.parent() {
tokio::fs::create_dir_all(parent)
.await
.wrap_err("failed to create parent folder")?;
// If Content-Disposition header is not available, extract from URL
let path = Path::new(response.url().as_str());
if let Some(name) = path.file_name() {
if let Some(filename) = name.to_str() {
return Some(filename.to_string());
}
}

None
}

pub async fn download_file<T>(url: T, target_dir: &Path) -> Result<PathBuf, eyre::ErrReport>
where
T: reqwest::IntoUrl + std::fmt::Display + Copy,
{
tokio::fs::create_dir_all(&target_dir)
.await
.wrap_err("failed to create parent folder")?;

let response = reqwest::get(url)
.await
.wrap_err_with(|| format!("failed to request operator from `{url}`"))?
.wrap_err_with(|| format!("failed to request operator from `{url}`"))?;

let filename = get_filename(&response).context("Could not find a filename")?;
let bytes = response
.bytes()
.await
.wrap_err("failed to read operator from `{uri}`")?;
let mut file = tokio::fs::File::create(target_path)
let path = target_dir.join(filename);
let mut file = tokio::fs::File::create(&path)
.await
.wrap_err("failed to create target file")?;
file.write_all(&response)
file.write_all(&bytes)
.await
.wrap_err("failed to write downloaded operator to file")?;
file.sync_all().await.wrap_err("failed to `sync_all`")?;
@@ -39,5 +55,5 @@ where
.await
.wrap_err("failed to make downloaded file executable")?;

Ok(())
Ok(path.to_path_buf())
}

Loading…
Cancel
Save