diff --git a/Cargo.lock b/Cargo.lock index 474601b..99fd3d5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -340,7 +340,7 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "corgea" -version = "1.8.7" +version = "1.8.8" dependencies = [ "chrono", "clap", diff --git a/Cargo.toml b/Cargo.toml index 608ffbd..35a859a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "corgea" -version = "1.8.7" +version = "1.8.8" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/scan.rs b/src/scan.rs index 184dbdd..3265cd8 100644 --- a/src/scan.rs +++ b/src/scan.rs @@ -186,14 +186,15 @@ pub fn upload_scan(config: &Config, paths: Vec, scanner: String, input: let mut success = false; while attempts < 3 && !success { - let form = reqwest::blocking::multipart::Form::new() - .file("file", fp) - .expect("Failed to read file"); - debug(&format!("POST: {}", src_upload_url)); - let res = client.post(&src_upload_url) - .multipart(form) - .send(); + let res = utils::api::retry_on_network_error("file upload", || { + let form = reqwest::blocking::multipart::Form::new() + .file("file", fp) + .expect("Failed to read file"); + client.post(&src_upload_url) + .multipart(form) + .send() + }); match res { Ok(response) => { @@ -211,8 +212,9 @@ pub fn upload_scan(config: &Config, paths: Vec, scanner: String, input: } } Err(e) => { - eprintln!("Failed to send request: {}", e); - std::process::exit(1); + upload_error_count += 1; + eprintln!("Failed to upload file {} after network retries: {}", path, e); + break; } } } @@ -258,12 +260,14 @@ pub fn upload_scan(config: &Config, paths: Vec, scanner: String, input: for (index, chunk) in input_bytes.chunks(chunk_size).enumerate() { debug(&format!("POST: {} (chunk {}/{})", scan_upload_url, index + 1, total_chunks)); - let response = client.post(&scan_upload_url) - .header(header::CONTENT_TYPE, "application/json") - .header("Upload-Offset", offset.to_string()) - .header("Upload-Length", input_size.to_string()) - .body(chunk.to_vec()) - .send(); + let response = utils::api::retry_on_network_error("scan chunk upload", || { + client.post(&scan_upload_url) + .header(header::CONTENT_TYPE, "application/json") + .header("Upload-Offset", offset.to_string()) + .header("Upload-Length", input_size.to_string()) + .body(chunk.to_vec()) + .send() + }); let should_break = match &response { Ok(res) => { @@ -308,10 +312,12 @@ pub fn upload_scan(config: &Config, paths: Vec, scanner: String, input: last_response.expect("Failed to upload scan.") } else { debug(&format!("POST: {}", scan_upload_url)); - client.post(&scan_upload_url) - .header(header::CONTENT_TYPE, "application/json") - .body(input.clone()) - .send() + utils::api::retry_on_network_error("scan upload", || { + client.post(&scan_upload_url) + .header(header::CONTENT_TYPE, "application/json") + .body(input.clone()) + .send() + }) }; let mut sast_scan_id: Option = None; @@ -381,14 +387,15 @@ pub fn upload_scan(config: &Config, paths: Vec, scanner: String, input: if git_config_path.exists() { debug("Uploading .git/config"); - let form = reqwest::blocking::multipart::Form::new() - .file("file", git_config_path) - .expect("Failed to read file"); - debug(&format!("POST: {}", git_config_upload_url)); - let res = client.post(&git_config_upload_url) - .multipart(form) - .send(); + let res = utils::api::retry_on_network_error("git config upload", || { + let form = reqwest::blocking::multipart::Form::new() + .file("file", git_config_path) + .expect("Failed to read file"); + client.post(&git_config_upload_url) + .multipart(form) + .send() + }); match res { Ok(response) => { diff --git a/src/utils/api.rs b/src/utils/api.rs index f0e8a59..d008ad3 100644 --- a/src/utils/api.rs +++ b/src/utils/api.rs @@ -144,6 +144,34 @@ pub fn http_client() -> HttpClient { HttpClient { inner: SHARED_CLIENT.clone() } } +#[cfg(not(test))] +const RETRY_BACKOFF_SECS: &[u64] = &[1, 2, 4, 8, 16, 32]; + +#[cfg(test)] +const RETRY_BACKOFF_SECS: &[u64] = &[0, 0, 0, 0, 0, 0]; + +pub fn retry_on_network_error(operation: &str, mut make_request: F) -> reqwest::Result +where + F: FnMut() -> reqwest::Result, +{ + let mut attempt = 0usize; + loop { + match make_request() { + Ok(result) => return Ok(result), + Err(e) if (e.is_connect() || e.is_timeout()) && attempt < RETRY_BACKOFF_SECS.len() => { + let delay = RETRY_BACKOFF_SECS[attempt]; + eprintln!( + "Network error during {}: {}. Retrying in {}s... ({}/{})", + operation, e, delay, attempt + 1, RETRY_BACKOFF_SECS.len() + ); + std::thread::sleep(std::time::Duration::from_secs(delay)); + attempt += 1; + } + Err(e) => return Err(e), + } + } +} + fn check_for_warnings(headers: &HeaderMap, status: StatusCode) { if let Some(warning) = headers.get("warning") { let warnings = warning.to_str().unwrap().split(','); @@ -913,3 +941,133 @@ pub struct SCAIssuesResponse { pub total_pages: u32, pub total_issues: u32, } + +#[cfg(test)] +mod tests { + use super::*; + use std::cell::Cell; + use std::net::TcpListener; + use std::thread; + use std::time::Duration; + + fn connection_refused_error() -> reqwest::Error { + let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + let port = listener.local_addr().expect("failed to get listener addr").port(); + drop(listener); + + reqwest::blocking::Client::builder() + .connect_timeout(Duration::from_secs(1)) + .build() + .expect("failed to build client") + .get(format!("http://127.0.0.1:{port}")) + .send() + .expect_err("expected connection error") + } + + fn timeout_error() -> reqwest::Error { + let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind ephemeral port"); + let port = listener.local_addr().expect("failed to get listener addr").port(); + + thread::spawn(move || { + if let Ok((_, _)) = listener.accept() { + thread::sleep(Duration::from_secs(30)); + } + }); + + reqwest::blocking::Client::builder() + .timeout(Duration::from_millis(200)) + .build() + .expect("failed to build client") + .get(format!("http://127.0.0.1:{port}")) + .send() + .expect_err("expected timeout error") + } + + fn non_retryable_error() -> reqwest::Error { + let err = reqwest::blocking::Client::new() + .get("http://[::1:bad") + .send() + .expect_err("expected request error"); + + assert!( + !err.is_connect() && !err.is_timeout(), + "expected a non-retryable reqwest error, got: {err}" + ); + err + } + + #[test] + fn retry_on_network_error_returns_ok_on_first_success() { + let attempts = Cell::new(0); + + let result = retry_on_network_error("test operation", || { + attempts.set(attempts.get() + 1); + Ok("success") + }); + + assert_eq!(result.unwrap(), "success"); + assert_eq!(attempts.get(), 1); + } + + #[test] + fn retry_on_network_error_retries_connect_errors_then_succeeds() { + let attempts = Cell::new(0); + + let result = retry_on_network_error("test operation", || { + let attempt = attempts.get() + 1; + attempts.set(attempt); + if attempt < 3 { + Err(connection_refused_error()) + } else { + Ok(42) + } + }); + + assert_eq!(result.unwrap(), 42); + assert_eq!(attempts.get(), 3); + } + + #[test] + fn retry_on_network_error_retries_timeout_errors() { + let attempts = Cell::new(0); + + let result = retry_on_network_error("test operation", || { + let attempt = attempts.get() + 1; + attempts.set(attempt); + if attempt == 1 { + Err(timeout_error()) + } else { + Ok("recovered") + } + }); + + assert_eq!(result.unwrap(), "recovered"); + assert_eq!(attempts.get(), 2); + } + + #[test] + fn retry_on_network_error_does_not_retry_non_network_errors() { + let attempts = Cell::new(0); + + let result: reqwest::Result<()> = retry_on_network_error("test operation", || { + attempts.set(attempts.get() + 1); + Err(non_retryable_error()) + }); + + assert!(result.is_err()); + assert_eq!(attempts.get(), 1); + } + + #[test] + fn retry_on_network_error_gives_up_after_max_retries() { + let attempts = Cell::new(0); + + let result: reqwest::Result<()> = retry_on_network_error("test operation", || { + attempts.set(attempts.get() + 1); + Err(connection_refused_error()) + }); + + assert!(result.is_err()); + assert_eq!(attempts.get(), RETRY_BACKOFF_SECS.len() + 1); + } +}