1use reqwest::blocking::{Client as ReqwestClient, RequestBuilder};
23use rustls::RootCertStore;
24use std::io::Write;
25use std::path::PathBuf;
26use std::time::Duration;
27use thiserror::Error;
28
/// Timeout, in seconds, configured on every [`Client`] built by [`Client::new`].
///
/// The adaptive chunk sizing of [`Client::download`] is derived from this value as well.
pub const CLIENT_TIMEOUT_SECS: u64 = 30;
30
31const USER_AGENT: &str = concat!("github.com/OGKevin/cadmus/", env!("GIT_VERSION"));
32
33#[derive(Error, Debug)]
34pub enum HttpError {
35 #[error("Failed to build HTTP client: {0}")]
36 Build(#[from] reqwest::Error),
37}
38
39const MIN_CHUNK_SIZE: usize = 256 * 1024;
40const MAX_CHUNK_SIZE: usize = 10 * 1024 * 1024;
41const INITIAL_CHUNK_SIZE: usize = 1024 * 1024;
42const TARGET_CHUNK_SECS: f64 = CLIENT_TIMEOUT_SECS as f64 * 0.8;
44const MAX_RETRIES: usize = 3;
45
46#[derive(Error, Debug)]
48pub enum ChunkedDownloadError {
49 #[error("HTTP request error: {0}")]
50 Request(#[from] reqwest::Error),
51 #[error("IO error: {0}")]
52 Io(#[from] std::io::Error),
53}
54
55pub struct Client {
75 client: ReqwestClient,
76}
77
78impl Client {
79 pub fn new() -> Result<Self, HttpError> {
80 let root_store = build_root_store();
81
82 let tls_config = rustls::ClientConfig::builder()
83 .with_root_certificates(root_store)
84 .with_no_client_auth();
85
86 let client = ReqwestClient::builder()
87 .use_preconfigured_tls(tls_config)
88 .user_agent(USER_AGENT)
89 .timeout(Duration::from_secs(CLIENT_TIMEOUT_SECS))
90 .build()
91 .map_err(HttpError::Build)?;
92
93 tracing::debug!("HTTP client built successfully");
94 Ok(Self { client })
95 }
96
97 pub fn head(&self, url: &str) -> RequestBuilder {
98 self.client.head(url)
99 }
100
101 pub fn get(&self, url: &str) -> RequestBuilder {
102 self.client.get(url)
103 }
104
105 pub fn post(&self, url: &str) -> RequestBuilder {
106 self.client.post(url)
107 }
108
109 pub fn into_reqwest(self) -> ReqwestClient {
112 self.client
113 }
114
115 #[cfg_attr(
151 feature = "tracing",
152 tracing::instrument(skip(self, request_builder, progress_callback))
153 )]
154 pub fn download<B, F>(
155 &self,
156 url: &str,
157 total_size: u64,
158 dest: &PathBuf,
159 request_builder: B,
160 progress_callback: &mut F,
161 ) -> Result<(), ChunkedDownloadError>
162 where
163 B: Fn(&str) -> RequestBuilder,
164 F: FnMut(u64, u64),
165 {
166 progress_callback(0, total_size);
167
168 tracing::debug!(url = %url, "Downloading file");
169 tracing::debug!(path = ?dest, "Download destination");
170
171 let mut file = std::fs::File::create(dest)?;
172
173 let mut downloaded = 0u64;
174 let mut chunk_size = INITIAL_CHUNK_SIZE;
175
176 tracing::debug!(
177 initial_chunk_size = INITIAL_CHUNK_SIZE,
178 "Starting chunked download"
179 );
180
181 while downloaded < total_size {
182 let chunk_start = downloaded;
183 let chunk_end = std::cmp::min(downloaded + chunk_size as u64 - 1, total_size - 1);
184
185 tracing::debug!(
186 chunk_start,
187 chunk_end,
188 chunk_size,
189 total_size,
190 "Downloading chunk"
191 );
192
193 let start = std::time::Instant::now();
194 let chunk_data =
195 Self::download_chunk_with_retries(url, chunk_start, chunk_end, &request_builder)?;
196 let elapsed_secs = start.elapsed().as_secs_f64();
197
198 file.write_all(&chunk_data)?;
199 downloaded += chunk_data.len() as u64;
200
201 if elapsed_secs > 0.0 {
202 let throughput = chunk_data.len() as f64 / elapsed_secs;
203 chunk_size = ((throughput * TARGET_CHUNK_SECS) as usize)
204 .clamp(MIN_CHUNK_SIZE, MAX_CHUNK_SIZE);
205 tracing::debug!(
206 elapsed_secs,
207 throughput_bytes_per_sec = throughput as u64,
208 next_chunk_size = chunk_size,
209 "Adjusted chunk size"
210 );
211 }
212
213 progress_callback(downloaded, total_size);
214
215 tracing::debug!(
216 downloaded,
217 total_size,
218 progress_percent = (downloaded as f64 / total_size as f64) * 100.0,
219 "Download progress"
220 );
221 }
222
223 tracing::debug!(bytes = downloaded, "Download complete");
224 tracing::debug!(path = ?dest, "Saved file");
225
226 Ok(())
227 }
228
229 #[cfg_attr(feature = "tracing", tracing::instrument(skip(request_builder)))]
235 fn download_chunk_with_retries<B>(
236 url: &str,
237 start: u64,
238 end: u64,
239 request_builder: &B,
240 ) -> Result<Vec<u8>, ChunkedDownloadError>
241 where
242 B: Fn(&str) -> RequestBuilder,
243 {
244 let mut last_error = None;
245
246 for attempt in 1..=MAX_RETRIES {
247 match Self::download_chunk(url, start, end, request_builder) {
248 Ok(data) => {
249 if attempt > 1 {
250 tracing::debug!(
251 attempt,
252 max_retries = MAX_RETRIES,
253 "Chunk download succeeded after retry"
254 );
255 }
256 return Ok(data);
257 }
258 Err(e) => {
259 tracing::warn!(
260 attempt,
261 max_retries = MAX_RETRIES,
262 error = %e,
263 "Chunk download failed"
264 );
265 last_error = Some(e);
266
267 if attempt < MAX_RETRIES {
268 let backoff_ms = 1000 * (2u64.pow(attempt as u32 - 1));
269 tracing::debug!(backoff_ms, "Retrying after backoff");
270 std::thread::sleep(Duration::from_millis(backoff_ms));
271 }
272 }
273 }
274 }
275
276 Err(last_error.expect("MAX_RETRIES >= 1, so last_error is always set"))
277 }
278
279 #[cfg_attr(feature = "tracing", tracing::instrument(skip(request_builder)))]
285 fn download_chunk<B>(
286 url: &str,
287 start: u64,
288 end: u64,
289 request_builder: &B,
290 ) -> Result<Vec<u8>, ChunkedDownloadError>
291 where
292 B: Fn(&str) -> RequestBuilder,
293 {
294 let range_header = format!("bytes={}-{}", start, end);
295
296 let bytes = request_builder(url)
297 .header("Range", range_header)
298 .send()?
299 .error_for_status()?
300 .bytes()?;
301
302 Ok(bytes.to_vec())
303 }
304}
305
306impl Clone for Client {
307 fn clone(&self) -> Self {
308 Self {
309 client: self.client.clone(),
310 }
311 }
312}
313
314fn build_root_store() -> RootCertStore {
315 let mut store = RootCertStore::empty();
316 store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
317 store
318}