Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 30 additions & 68 deletions src/download/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,35 +163,9 @@ impl<'a> Download<'a> {
async fn download_file_(&self) -> anyhow::Result<()> {
debug!(url = %self.url, "downloading file");

// This callback will write the download to disk and optionally
// hash the contents, then forward the notification up the stack
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
if let Event::DownloadDataReceived(data) = msg
&& let Some(h) = self.hasher.as_ref()
{
h.borrow_mut().update(data);
}

match msg {
Event::DownloadContentLengthReceived(len) => {
if let Some(status) = self.status {
status.received_length(len)
}
}
Event::DownloadDataReceived(data) => {
if let Some(status) = self.status {
status.received_data(data.len())
}
}
Event::ResumingPartialDownload => debug!("resuming partial download"),
}

Ok(())
};

// Download the file

let res = self.download_to_path(Some(callback)).await;
let res = self.download_to_path().await;

// The notification should only be sent if the download was successful (i.e. didn't timeout)
if let Some(status) = self.status {
Expand All @@ -204,8 +178,8 @@ impl<'a> Download<'a> {
res
}

async fn download_to_path(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
let Err(err) = self.download_impl(callback).await else {
async fn download_to_path(&self) -> anyhow::Result<()> {
let Err(err) = self.download_impl().await else {
return Ok(());
};

Expand All @@ -224,14 +198,14 @@ impl<'a> Download<'a> {
)
}

async fn download_impl(&self, callback: Option<DownloadCallback<'_>>) -> anyhow::Result<()> {
let (file, resume_from) = if self.resume {
async fn download_impl(&self) -> anyhow::Result<()> {
let (mut file, resume_from) = if self.resume {
// TODO: blocking call
let possible_partial = OpenOptions::new().read(true).open(self.path);

let downloaded_so_far = if let Ok(mut partial) = possible_partial {
if let Some(cb) = callback {
cb(Event::ResumingPartialDownload)?;
if self.status.is_some() || self.hasher.is_some() {
debug!("resuming partial download");

let mut buf = vec![0; 32768];
let mut downloaded_so_far = 0;
Expand All @@ -241,7 +215,7 @@ impl<'a> Download<'a> {
if n == 0 {
break;
}
cb(Event::DownloadDataReceived(&buf[..n]))?;
self.data_received(&buf[..n]);
}

downloaded_so_far
Expand Down Expand Up @@ -276,7 +250,6 @@ impl<'a> Download<'a> {
)
};

let file = RefCell::new(file);
let client = match self.options.tls {
#[cfg(feature = "reqwest-rustls-tls")]
Tls::Rustls => rustls_client(self.options.timeout)?,
Expand All @@ -285,34 +258,18 @@ impl<'a> Download<'a> {
};

// TODO: the sync callback will stall the async runtime if IO calls block, which is OS dependent. Rearrange.
self.execute(
resume_from,
&|event| {
if let Event::DownloadDataReceived(data) = event {
file.borrow_mut()
.write_all(data)
.context("unable to write download to disk")?;
}
match callback {
Some(cb) => cb(event),
None => Ok(()),
}
},
client,
)
.await?;
self.execute(&mut file, resume_from, client).await?;

file.borrow_mut()
.sync_data()
file.sync_data()
.context("unable to sync download to disk")?;

Ok::<(), anyhow::Error>(())
}

async fn execute(
&self,
file: &mut fs::File,
resume_from: u64,
callback: &dyn Fn(Event<'_>) -> anyhow::Result<()>,
client: &Client,
) -> anyhow::Result<()> {
// Short-circuit reqwest for the "file:" URL scheme
Expand All @@ -339,7 +296,10 @@ impl<'a> Download<'a> {
if bytes_read == 0 {
break;
}
callback(Event::DownloadDataReceived(&buffer[0..bytes_read]))?;

file.write_all(&buffer[0..bytes_read])
.context("unable to write download to disk")?;
self.data_received(&buffer[0..bytes_read]);
}

return Ok(());
Expand All @@ -366,16 +326,29 @@ impl<'a> Download<'a> {

if let Some(len) = res.content_length() {
let len = len + resume_from;
callback(Event::DownloadContentLengthReceived(len))?;
if let Some(status) = self.status {
status.received_length(len);
}
}

let mut stream = res.bytes_stream();
while let Some(item) = stream.next().await {
let bytes = item.map_err(DownloadError::Reqwest)?;
callback(Event::DownloadDataReceived(&bytes))?;
file.write_all(&bytes)
.context("unable to write download to disk")?;
self.data_received(&bytes);
}
Ok(())
}

fn data_received(&self, data: &[u8]) {
if let Some(hasher) = &self.hasher {
hasher.borrow_mut().update(data);
}
if let Some(status) = self.status {
status.received_data(data.len());
}
}
}

pub(crate) fn is_network_failure(err: &anyhow::Error) -> bool {
Expand Down Expand Up @@ -407,17 +380,6 @@ enum Tls {
NativeTls,
}

#[derive(Debug, Copy, Clone)]
enum Event<'a> {
ResumingPartialDownload,
/// Received the Content-Length of the to-be downloaded data.
DownloadContentLengthReceived(u64),
/// Received some data.
DownloadDataReceived(&'a [u8]),
}

type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> anyhow::Result<()>;

fn client_generic() -> ClientBuilder {
Client::builder()
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying
Expand Down
59 changes: 5 additions & 54 deletions src/download/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ mod reqwest {
use std::env::set_var;
use std::error::Error;
use std::net::TcpListener;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
use std::time::Duration;

Expand All @@ -30,7 +29,7 @@ mod reqwest {
use url::Url;

use super::{scrub_env, serve_file, tmp_dir, write_file};
use crate::download::{DownloadOptions, Event, Tls};
use crate::download::{DownloadOptions, Tls};

const OPTIONS: DownloadOptions = DownloadOptions {
tls: DOWNLOAD_BACKEND,
Expand Down Expand Up @@ -118,61 +117,13 @@ mod reqwest {
OPTIONS
.start(&from_url, &target_path)
.with_resume()
.download_to_path(None)
.download_to_path()
.await
.expect("Test download failed");

assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
}

#[tokio::test]
Copy link
Copy Markdown
Member

@rami3l rami3l May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to the future reviewer: the test right above is already a direct-style download resumption test, so it's safe to delete this one.

async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
let _guard = scrub_env().await;
let tmpdir = tmp_dir();
let target_path = tmpdir.path().join("downloaded");
write_file(&target_path, "123");

let addr = serve_file(b"xxx45".to_vec(), true);

let from_url = format!("http://{addr}").parse().unwrap();

let callback_partial = AtomicBool::new(false);
let callback_len = Mutex::new(None);
let received_in_callback = Mutex::new(Vec::new());

OPTIONS
.start(&from_url, &target_path)
.with_resume()
.download_to_path(Some(&|msg| {
match msg {
Event::ResumingPartialDownload => {
assert!(!callback_partial.load(Ordering::SeqCst));
callback_partial.store(true, Ordering::SeqCst);
}
Event::DownloadContentLengthReceived(len) => {
let mut flag = callback_len.lock().unwrap();
assert!(flag.is_none());
*flag = Some(len);
}
Event::DownloadDataReceived(data) => {
for b in data.iter() {
received_in_callback.lock().unwrap().push(*b);
}
}
}

Ok(())
}))
.await
.expect("Test download failed");

assert!(callback_partial.into_inner());
assert_eq!(*callback_len.lock().unwrap(), Some(5));
let observed_bytes = received_in_callback.into_inner().unwrap();
assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']);
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
}

#[tokio::test]
async fn resume_partial_fails_if_server_ignores_range() {
let _guard = scrub_env().await;
Expand All @@ -186,7 +137,7 @@ mod reqwest {
OPTIONS
.start(&from_url, &target_path)
.with_resume()
.download_to_path(None)
.download_to_path()
.await
.expect_err("download should fail if server ignores range");

Expand All @@ -210,7 +161,7 @@ mod reqwest {
}
.start(&from_url, &target_path)
.with_resume()
.download_to_path(None)
.download_to_path()
.await
.expect_err("download should fail with a connect error");

Expand Down