diff --git a/Cargo.lock b/Cargo.lock index 682a51c..c141243 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,6 +160,7 @@ dependencies = [ "crossterm", "dirs 5.0.1", "enclaveapp-app-storage", + "enclaveapp-core", "enclaveapp-wsl", "predicates", "regex", @@ -168,8 +169,10 @@ dependencies = [ "serde_json", "tempfile", "tokio", + "toml 0.8.23", "tracing", "tracing-subscriber", + "winresource", "zeroize", ] @@ -180,14 +183,16 @@ dependencies = [ "base64 0.22.1", "chrono", "dirs 5.0.1", + "enclaveapp-core", "regex", "reqwest", + "roxmltree", "serde", "serde_json", "tempfile", "thiserror 1.0.69", "tokio", - "toml", + "toml 0.8.23", "tracing", "url", "wiremock", @@ -631,11 +636,12 @@ name = "enclaveapp-core" version = "0.1.0" dependencies = [ "dirs 6.0.0", + "fs2", "libc", "serde", "serde_json", "thiserror 2.0.18", - "toml", + "toml 0.8.23", ] [[package]] @@ -772,6 +778,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures" version = "0.3.32" @@ -1941,6 +1957,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "roxmltree" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97" + [[package]] name = "rpassword" version = "7.4.0" @@ -2116,6 +2138,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6662b5879511e06e8999a8a235d848113e942c9124f211511b16466ee2995f26" +dependencies = [ + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -2433,11 +2464,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" dependencies = [ "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", "toml_edit", ] +[[package]] +name = "toml" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81f3d15e84cbcd896376e6730314d59fb5a87f31e4b038454184435cd57defee" +dependencies = [ + "indexmap", + "serde_core", + "serde_spanned 1.1.1", + "toml_datetime 1.1.1+spec-1.1.0", + "toml_parser", + "toml_writer", + "winnow 1.0.1", +] + [[package]] name = "toml_datetime" version = "0.6.11" @@ -2447,6 +2493,15 @@ dependencies = [ "serde", ] +[[package]] +name = "toml_datetime" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3165f65f62e28e0115a00b2ebdd37eb6f3b641855f9d636d3cd4103767159ad7" +dependencies = [ + "serde_core", +] + [[package]] name = "toml_edit" version = "0.22.27" @@ -2455,10 +2510,19 @@ checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" dependencies = [ "indexmap", "serde", - "serde_spanned", - "toml_datetime", + "serde_spanned 0.6.9", + "toml_datetime 0.6.11", "toml_write", - "winnow", + "winnow 0.7.15", +] + +[[package]] +name = "toml_parser" +version = "1.1.2+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" +dependencies = [ + "winnow 1.0.1", ] [[package]] @@ -2467,6 +2531,12 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" +[[package]] +name = "toml_writer" +version = "1.1.1+spec-1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "756daf9b1013ebe47a8776667b466417e2d4c5679d441c26230efd9ef78692db" + [[package]] name = "tower" version = "0.5.3" @@ -3227,6 +3297,22 @@ dependencies = [ "memchr", ] +[[package]] +name = "winnow" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dac053f1cd375980747450bfc7250c264eaae0583872e845c0c7cd578872b5" + +[[package]] +name = "winresource" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0986a8b1d586b7d3e4fe3d9ea39fb451ae22869dcea4aa109d287a374d866087" +dependencies = [ + "toml 1.1.2+spec-1.1.0", + "version_check", +] + [[package]] name = "wiremock" version = "0.6.5" diff --git a/awsenc-cli/Cargo.toml b/awsenc-cli/Cargo.toml index ba6a367..6f5a5e4 100644 --- a/awsenc-cli/Cargo.toml +++ b/awsenc-cli/Cargo.toml @@ -13,6 +13,7 @@ workspace = true [dependencies] awsenc-core = { path = "../awsenc-core" } +enclaveapp-core = { workspace = true } enclaveapp-app-storage = { workspace = true } enclaveapp-wsl = { workspace = true } clap = { version = "4", features = ["derive", "env", "string"] } @@ -36,4 +37,5 @@ winresource = "0.1" assert_cmd = "2" predicates = "3" tempfile = "3" +toml = "0.8" enclaveapp-app-storage = { workspace = true, features = ["mock"] } diff --git a/awsenc-cli/src/auth.rs b/awsenc-cli/src/auth.rs index 4afe119..d347e82 100644 --- a/awsenc-cli/src/auth.rs +++ b/awsenc-cli/src/auth.rs @@ -2,6 +2,7 @@ use std::io::{IsTerminal, Read}; use std::time::Duration; use chrono::Utc; +use tracing::warn; use zeroize::Zeroizing; use awsenc_core::cache::{ @@ -53,8 +54,15 @@ pub async fn run_auth( }; let creds = obtain_credentials(&okta, &session_token, &resolved).await?; + let okta_session = match okta.create_session(&session_token).await { + Ok(session) => Some(session), + Err(err) => { + warn!("failed to create reusable Okta session: {err}"); + None + } + }; - encrypt_and_cache(profile, storage, &creds, &session_token)?; + encrypt_and_cache(profile, storage, &creds, okta_session.as_ref())?; usage::record_usage(profile); let remaining = creds @@ -104,6 +112,7 @@ async fn obtain_credentials( session_token: &Zeroizing, resolved: &config::ResolvedConfig, ) -> Result { + ensure_supported_secondary_role(resolved)?; eprintln!("Getting SAML assertion..."); let saml_assertion = okta .get_saml_assertion(session_token, &resolved.okta_application) @@ -136,33 +145,45 @@ async fn obtain_credentials( Ok(creds) } +fn ensure_supported_secondary_role(resolved: &config::ResolvedConfig) -> Result<()> { + if let Some(role) = resolved.secondary_role.as_deref() { + return Err(format!( + "secondary_role '{role}' is configured but chained role assumption is not supported yet" + ) + .into()); + } + Ok(()) +} + #[allow(clippy::cast_sign_loss)] fn encrypt_and_cache( profile: &str, storage: &dyn EncryptionStorage, creds: &awsenc_core::credential::AwsCredentials, - session_token: &Zeroizing, + okta_session: Option<&OktaSession>, ) -> Result<()> { let creds_json = serde_json::to_vec(creds)?; let aws_ciphertext = storage.encrypt(&creds_json)?; - let okta_session = OktaSession { - session_id: session_token.as_str().to_owned(), - expiration: Utc::now() + chrono::Duration::hours(2), - }; - let okta_json = serde_json::to_vec(&okta_session)?; - let okta_ciphertext = storage.encrypt(&okta_json)?; - let cache_file = CacheFile { header: CacheHeader { magic: MAGIC, version: FORMAT_VERSION, - flags: FLAG_HAS_OKTA_SESSION, + flags: if okta_session.is_some() { + FLAG_HAS_OKTA_SESSION + } else { + 0 + }, credential_expiration: creds.expiration.timestamp() as u64, - okta_session_expiration: okta_session.expiration.timestamp() as u64, + okta_session_expiration: okta_session + .map_or(0, |session| session.expiration.timestamp() as u64), }, aws_ciphertext, - okta_session_ciphertext: Some(okta_ciphertext), + okta_session_ciphertext: okta_session + .map(serde_json::to_vec) + .transpose()? + .map(|json| storage.encrypt(&json)) + .transpose()?, }; cache::write_cache(profile, &cache_file)?; @@ -290,17 +311,16 @@ mod tests { session_token: Zeroizing::new("sessiontoken456".to_string()), expiration: Utc::now() + chrono::Duration::hours(1), }; - let session_token = Zeroizing::new("okta-session-token".to_string()); + let okta_session = OktaSession { + session_id: "okta-session-id".to_string(), + expiration: Utc::now() + chrono::Duration::hours(2), + }; // Test encrypt_and_cache by verifying it constructs correct structures // without relying on file I/O (which depends on HOME env var) let creds_json = serde_json::to_vec(&creds).unwrap(); let aws_ciphertext = storage.encrypt(&creds_json).unwrap(); - let okta_session = OktaSession { - session_id: session_token.as_str().to_owned(), - expiration: Utc::now() + chrono::Duration::hours(2), - }; let okta_json = serde_json::to_vec(&okta_session).unwrap(); let okta_ciphertext = storage.encrypt(&okta_json).unwrap(); @@ -313,7 +333,7 @@ mod tests { // Verify Okta session can be decrypted let okta_plaintext = storage.decrypt(&okta_ciphertext).unwrap(); let recovered_session: OktaSession = serde_json::from_slice(&okta_plaintext).unwrap(); - assert_eq!(recovered_session.session_id, "okta-session-token"); + assert_eq!(recovered_session.session_id, "okta-session-id"); // Verify CacheFile structure #[allow(clippy::cast_sign_loss)] @@ -333,6 +353,36 @@ mod tests { assert!(cache_file.header.has_okta_session()); } + #[test] + fn encrypt_and_cache_without_okta_session_clears_session_flag() { + use enclaveapp_app_storage::mock::MockEncryptionStorage as MockStorage; + + let storage = MockStorage::new(); + let creds = awsenc_core::credential::AwsCredentials { + access_key_id: "AKIATEST".to_string(), + secret_access_key: Zeroizing::new("secretkey123".to_string()), + session_token: Zeroizing::new("sessiontoken456".to_string()), + expiration: Utc::now() + chrono::Duration::hours(1), + }; + + let creds_json = serde_json::to_vec(&creds).unwrap(); + let aws_ciphertext = storage.encrypt(&creds_json).unwrap(); + let cache_file = CacheFile { + header: CacheHeader { + magic: MAGIC, + version: FORMAT_VERSION, + flags: 0, + credential_expiration: creds.expiration.timestamp() as u64, + okta_session_expiration: 0, + }, + aws_ciphertext, + okta_session_ciphertext: None, + }; + + assert_eq!(cache_file.header.flags, 0); + assert!(!cache_file.header.has_okta_session()); + } + #[test] fn resolved_profile_positional_preferred() { let args = AuthArgs { @@ -358,4 +408,23 @@ mod tests { let args = default_auth_args(); assert_eq!(args.resolved_profile(), None); } + + #[test] + fn secondary_role_configuration_fails_fast() { + let resolved = config::ResolvedConfig { + okta_organization: "org.okta.com".into(), + okta_user: "user@example.com".into(), + okta_application: "https://org.okta.com/app".into(), + okta_role: "arn:aws:iam::123:role/Primary".into(), + okta_factor: "push".into(), + okta_duration: 3600, + biometric: false, + refresh_window_seconds: 600, + secondary_role: Some("arn:aws:iam::456:role/Secondary".into()), + region: None, + }; + + let err = ensure_supported_secondary_role(&resolved).unwrap_err(); + assert!(err.to_string().contains("secondary_role")); + } } diff --git a/awsenc-cli/src/exec.rs b/awsenc-cli/src/exec.rs index d842828..c70f948 100644 --- a/awsenc-cli/src/exec.rs +++ b/awsenc-cli/src/exec.rs @@ -30,6 +30,12 @@ pub async fn run_exec(args: &ExecArgs, storage: &dyn EncryptionStorage) -> Resul let creds = if let Some(c) = cached { c } else { + if !std::io::stdin().is_terminal() { + return Err( + "no cached credentials and stdin is not a TTY; run 'awsenc auth --pass-stdin' first" + .into(), + ); + } eprintln!("No cached credentials for '{profile}', authenticating..."); let auth_args = AuthArgs { profile_positional: Some(profile.to_owned()), @@ -77,7 +83,7 @@ pub async fn run_exec(args: &ExecArgs, storage: &dyn EncryptionStorage) -> Resul std::process::exit(status.code().unwrap_or(1)); } -fn resolve_exec_profile(args: &ExecArgs) -> Result { +pub(crate) fn resolve_exec_profile(args: &ExecArgs) -> Result { if let Some(p) = args.resolved_profile() { let global = config::load_global_config().unwrap_or_default(); return Ok(config::resolve_alias(p, &global)); @@ -143,33 +149,36 @@ mod tests { use super::*; use enclaveapp_app_storage::mock::MockEncryptionStorage as MockStorage; - // Mutex to serialize tests that modify the HOME env var - static HOME_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); - - fn setup_temp_home(tmp: &tempfile::TempDir) -> Option { + fn setup_temp_home(tmp: &tempfile::TempDir) -> (Option, Option) { let prev = std::env::var("HOME").ok(); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); let config_dir = tmp.path().join(".config").join("awsenc"); std::fs::create_dir_all(&config_dir).unwrap(); std::env::set_var("HOME", tmp.path()); - prev + std::env::set_var("XDG_CONFIG_HOME", tmp.path().join(".config")); + (prev, prev_xdg) } - fn restore_home(prev: Option) { + fn restore_home(prev: Option, prev_xdg: Option) { match prev { Some(v) => std::env::set_var("HOME", v), None => std::env::remove_var("HOME"), } + match prev_xdg { + Some(v) => std::env::set_var("XDG_CONFIG_HOME", v), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } } #[test] fn get_cached_credentials_returns_none_when_no_cache() { - let _lock = HOME_MUTEX.lock().expect("mutex poisoned"); + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - let prev = setup_temp_home(&tmp); + let (prev, prev_xdg) = setup_temp_home(&tmp); let storage = MockStorage::new(); let result = get_cached_credentials("nonexistent-profile-xyz", &storage).unwrap(); assert!(result.is_none()); - restore_home(prev); + restore_home(prev, prev_xdg); } #[test] @@ -177,9 +186,9 @@ mod tests { use awsenc_core::cache::{self, CacheFile, CacheHeader, FORMAT_VERSION, MAGIC}; use zeroize::Zeroizing; - let _lock = HOME_MUTEX.lock().expect("mutex poisoned"); + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - let prev = setup_temp_home(&tmp); + let (prev, prev_xdg) = setup_temp_home(&tmp); let storage = MockStorage::new(); let creds = AwsCredentials { @@ -214,7 +223,7 @@ mod tests { assert_eq!(recovered.access_key_id, "AKIATEST"); drop(cache::delete_cache(profile)); - restore_home(prev); + restore_home(prev, prev_xdg); } #[test] @@ -222,9 +231,9 @@ mod tests { use awsenc_core::cache::{self, CacheFile, CacheHeader, FORMAT_VERSION, MAGIC}; use zeroize::Zeroizing; - let _lock = HOME_MUTEX.lock().expect("mutex poisoned"); + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - let prev = setup_temp_home(&tmp); + let (prev, prev_xdg) = setup_temp_home(&tmp); let storage = MockStorage::new(); let creds = AwsCredentials { @@ -260,7 +269,7 @@ mod tests { ); drop(cache::delete_cache(profile)); - restore_home(prev); + restore_home(prev, prev_xdg); } #[test] @@ -295,9 +304,11 @@ mod tests { #[test] fn get_profile_region_returns_none_for_nonexistent() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - std::env::set_var("HOME", tmp.path()); + let (prev, prev_xdg) = setup_temp_home(&tmp); assert!(get_profile_region("nonexistent-profile").is_none()); + restore_home(prev, prev_xdg); } #[tokio::test] @@ -316,4 +327,18 @@ mod tests { "expected 'no command' error, got: {err}" ); } + + #[tokio::test] + async fn run_exec_without_cache_fails_fast_when_stdin_is_not_tty() { + let storage = MockStorage::new(); + let args = ExecArgs { + profile_positional: Some("test".to_string()), + profile_flag: None, + command: vec!["echo".to_string(), "hello".to_string()], + }; + let result = run_exec(&args, &storage).await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("stdin is not a TTY")); + } } diff --git a/awsenc-cli/src/install.rs b/awsenc-cli/src/install.rs index 80b228b..fb4c83e 100644 --- a/awsenc-cli/src/install.rs +++ b/awsenc-cli/src/install.rs @@ -1,6 +1,7 @@ use std::io::IsTerminal; use std::path::PathBuf; +use enclaveapp_core::metadata; use regex::Regex; use awsenc_core::cache; @@ -28,6 +29,7 @@ pub fn run_install(args: &InstallArgs) -> Result<()> { } } }; + config::validate_profile_name(&profile_name)?; // Build profile config from args let profile_config = ProfileConfig { @@ -38,6 +40,7 @@ pub fn run_install(args: &InstallArgs) -> Result<()> { factor: args.factor.clone(), duration: args.duration, }, + region: args.region.clone(), secondary_role: None, }; @@ -69,7 +72,7 @@ pub fn run_install(args: &InstallArgs) -> Result<()> { let managed_block = build_managed_block(&profile_name, binary_path_str, args.region.as_deref()); let updated = upsert_managed_block(&existing, &profile_name, &managed_block); - std::fs::write(&aws_config_path, &updated)?; + write_text_file(&aws_config_path, &updated)?; eprintln!("Updated {}", aws_config_path.display()); eprintln!("Installed profile '{profile_name}'"); @@ -85,13 +88,14 @@ pub fn run_uninstall(args: &UninstallArgs) -> Result<()> { .profile .as_deref() .ok_or("no profile specified; use --profile ")?; + config::validate_profile_name(profile_name)?; // Remove managed block from ~/.aws/config let aws_config_path = aws_config_path()?; if aws_config_path.exists() { let existing = std::fs::read_to_string(&aws_config_path)?; let updated = remove_managed_block(&existing, profile_name); - std::fs::write(&aws_config_path, &updated)?; + write_text_file(&aws_config_path, &updated)?; eprintln!("Removed managed block from {}", aws_config_path.display()); } @@ -177,6 +181,7 @@ pub fn run_migrate(args: &MigrateArgs) -> Result<()> { factor: entry.factor.clone(), duration: entry.duration, }, + region: entry.region.clone(), secondary_role: entry.secondary_role.as_ref().map(|r| SecondaryRoleConfig { role_arn: r.clone(), }), @@ -207,7 +212,7 @@ pub fn run_migrate(args: &MigrateArgs) -> Result<()> { updated = upsert_managed_block(&updated, profile_name, &block); } - std::fs::write(&aws_config_path, &updated)?; + write_text_file(&aws_config_path, &updated)?; eprintln!("\nMigrated {} profile(s)", migrated_profiles.len()); } @@ -229,13 +234,13 @@ fn aws_credentials_path() -> Result { } fn profile_config_path(name: &str) -> Result { - let dir = config::profiles_dir()?; - Ok(dir.join(format!("{name}.toml"))) + Ok(config::profile_config_path(name)?) } fn build_managed_block(profile_name: &str, binary_path: &str, region: Option<&str>) -> String { use std::fmt::Write; + let binary_path = quote_credential_process_arg(binary_path); let mut block = format!( "# --- BEGIN awsenc managed ({profile_name}) ---\n\ [profile {profile_name}]\n\ @@ -249,12 +254,9 @@ fn build_managed_block(profile_name: &str, binary_path: &str, region: Option<&st } fn upsert_managed_block(existing: &str, profile_name: &str, new_block: &str) -> String { - let begin = format!("# --- BEGIN awsenc managed ({profile_name}) ---"); - let end = format!("# --- END awsenc managed ({profile_name}) ---"); - - if let (Some(start), Some(end_pos)) = (existing.find(&begin), existing.find(&end)) { + if let Some((start, end)) = managed_block_range(existing, profile_name) { let before = &existing[..start]; - let after = &existing[end_pos + end.len()..]; + let after = &existing[end..]; format!("{before}{new_block}{after}") } else { // Append @@ -272,12 +274,9 @@ fn upsert_managed_block(existing: &str, profile_name: &str, new_block: &str) -> } fn remove_managed_block(existing: &str, profile_name: &str) -> String { - let begin = format!("# --- BEGIN awsenc managed ({profile_name}) ---"); - let end = format!("# --- END awsenc managed ({profile_name}) ---"); - - if let (Some(start), Some(end_pos)) = (existing.find(&begin), existing.find(&end)) { + if let Some((start, end)) = managed_block_range(existing, profile_name) { let before = &existing[..start]; - let after = &existing[end_pos + end.len()..]; + let after = &existing[end..]; // Clean up double newlines let after = after.trim_start_matches('\n'); let mut result = before.to_owned(); @@ -293,6 +292,34 @@ fn remove_managed_block(existing: &str, profile_name: &str) -> String { } } +fn managed_block_range(existing: &str, profile_name: &str) -> Option<(usize, usize)> { + let escaped = regex::escape(profile_name); + let pattern = format!( + r"(?ms)^# --- BEGIN awsenc managed \({escaped}\) ---\n.*?^# --- END awsenc managed \({escaped}\) ---\n?" + ); + let re = Regex::new(&pattern).ok()?; + re.find(existing).map(|m| (m.start(), m.end())) +} + +fn quote_credential_process_arg(arg: &str) -> String { + if !arg + .chars() + .any(|c| c.is_whitespace() || c == '"' || c == '\\') + { + return arg.to_owned(); + } + + format!("\"{}\"", arg.replace('\\', "\\\\").replace('"', "\\\"")) +} + +fn write_text_file(path: &std::path::Path, contents: &str) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + metadata::atomic_write(path, contents.as_bytes()) + .map_err(|e| format!("failed to write {}: {e}", path.display()).into()) +} + /// Parsed aws-okta-processor entry from an AWS config file. struct OktaProcessorEntry { profile_name: String, @@ -301,6 +328,7 @@ struct OktaProcessorEntry { role: Option, factor: Option, duration: Option, + region: Option, secondary_role: Option, } @@ -309,6 +337,7 @@ fn find_okta_processor_entries(content: &str) -> Vec { let mut entries = Vec::new(); let mut current_profile: Option = None; let mut current_cred_process: Option = None; + let mut current_region: Option = None; let profile_re = Regex::new(r"^\[(?:profile\s+)?([^\]]+)\]").expect("valid regex"); @@ -319,29 +348,36 @@ fn find_okta_processor_entries(content: &str) -> Vec { // Flush previous profile if let (Some(name), Some(cp)) = (current_profile.take(), current_cred_process.take()) { if cp.contains("aws-okta-processor") || cp.contains("aws_okta_processor") { - entries.push(parse_okta_processor_line(&name, &cp)); + entries.push(parse_okta_processor_line(&name, &cp, current_region.take())); } } current_profile = Some(caps[1].to_owned()); current_cred_process = None; + current_region = None; } else if trimmed.starts_with("credential_process") { if let Some(value) = trimmed.split_once('=').map(|(_, v)| v.trim()) { current_cred_process = Some(value.to_owned()); } + } else if trimmed.starts_with("region") { + current_region = trimmed.split_once('=').map(|(_, v)| v.trim().to_owned()); } } // Flush last profile if let (Some(name), Some(cp)) = (current_profile, current_cred_process) { if cp.contains("aws-okta-processor") || cp.contains("aws_okta_processor") { - entries.push(parse_okta_processor_line(&name, &cp)); + entries.push(parse_okta_processor_line(&name, &cp, current_region)); } } entries } -fn parse_okta_processor_line(profile_name: &str, command_line: &str) -> OktaProcessorEntry { +fn parse_okta_processor_line( + profile_name: &str, + command_line: &str, + region: Option, +) -> OktaProcessorEntry { fn extract_flag(line: &str, flag: &str) -> Option { let patterns = [format!("--{flag} "), format!("--{flag}=")]; for pat in &patterns { @@ -368,6 +404,7 @@ fn parse_okta_processor_line(profile_name: &str, command_line: &str) -> OktaProc role: extract_flag(command_line, "role"), factor: extract_flag(command_line, "factor"), duration: extract_flag(command_line, "duration").and_then(|d| d.parse().ok()), + region, secondary_role: extract_flag(command_line, "secondary-role"), } } @@ -392,6 +429,14 @@ mod tests { assert!(block.contains("region = us-west-2")); } + #[test] + fn build_managed_block_quotes_binary_with_spaces() { + let block = build_managed_block("prod", "/Applications/Aws Enc/awsenc", None); + assert!(block.contains( + "credential_process = \"/Applications/Aws Enc/awsenc\" serve --profile prod" + )); + } + #[test] fn upsert_managed_block_append() { let existing = "[profile other]\nsome = value\n"; @@ -446,6 +491,7 @@ region = us-west-2 ); assert_eq!(entries[0].factor.as_deref(), Some("push")); assert_eq!(entries[0].duration, Some(3600)); + assert_eq!(entries[0].region.as_deref(), Some("us-west-2")); } #[test] @@ -478,11 +524,13 @@ credential_process = aws-okta-processor authenticate --organization org2.okta.co assert_eq!(entries[0].organization.as_deref(), Some("org1.okta.com")); assert_eq!(entries[0].factor.as_deref(), Some("push")); assert!(entries[0].duration.is_none()); + assert_eq!(entries[0].region.as_deref(), Some("us-east-1")); assert_eq!(entries[1].profile_name, "account2"); assert_eq!(entries[1].organization.as_deref(), Some("org2.okta.com")); assert_eq!(entries[1].factor.as_deref(), Some("totp")); assert_eq!(entries[1].duration, Some(7200)); + assert!(entries[1].region.is_none()); } #[test] @@ -557,8 +605,10 @@ credential_process = aws-okta-processor authenticate --organization "my org.okta let entry = parse_okta_processor_line( "test", "aws-okta-processor authenticate --organization=org.okta.com --factor=push", + Some("us-west-1".into()), ); assert_eq!(entry.organization.as_deref(), Some("org.okta.com")); assert_eq!(entry.factor.as_deref(), Some("push")); + assert_eq!(entry.region.as_deref(), Some("us-west-1")); } } diff --git a/awsenc-cli/src/main.rs b/awsenc-cli/src/main.rs index ecd1485..60677a2 100644 --- a/awsenc-cli/src/main.rs +++ b/awsenc-cli/src/main.rs @@ -22,6 +22,9 @@ mod usage; use cli::{Cli, Commands}; +#[cfg(test)] +pub(crate) static TEST_ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + #[tokio::main] #[allow(clippy::print_stderr)] async fn main() { @@ -154,31 +157,15 @@ fn resolve_biometric_for_profile(profile: &str, cli_biometric: bool) -> bool { } fn resolve_biometric_from_serve(args: &cli::ServeArgs) -> bool { - let profile = args - .profile - .clone() - .or_else(|| std::env::var("AWSENC_PROFILE").ok()) - .unwrap_or_default(); - - if profile.is_empty() { - return false; - } - - resolve_biometric_for_profile(&profile, false) + serve::resolve_serve_profile(args) + .map(|profile| resolve_biometric_for_profile(&profile, false)) + .unwrap_or(false) } fn resolve_biometric_from_exec(args: &cli::ExecArgs) -> bool { - let profile = args - .resolved_profile() - .map(str::to_owned) - .or_else(|| std::env::var("AWSENC_PROFILE").ok()) - .unwrap_or_default(); - - if profile.is_empty() { - return false; - } - - resolve_biometric_for_profile(&profile, false) + exec::resolve_exec_profile(args) + .map(|profile| resolve_biometric_for_profile(&profile, false)) + .unwrap_or(false) } #[allow(clippy::print_stderr, clippy::print_stdout)] @@ -429,22 +416,29 @@ mod tests { #[test] fn resolve_biometric_for_nonexistent_profile() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - std::env::set_var("HOME", tmp.path()); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); let result = resolve_biometric_for_profile("nonexistent-profile-xyz", false); assert!(!result); + restore_test_home(prev_home, prev_userprofile, prev_xdg); } #[test] fn resolve_biometric_for_profile_cli_override() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - std::env::set_var("HOME", tmp.path()); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); let result = resolve_biometric_for_profile("nonexistent-profile-xyz", true); assert!(result); + restore_test_home(prev_home, prev_userprofile, prev_xdg); } #[test] fn resolve_biometric_from_serve_empty_profile() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let tmp = tempfile::tempdir().unwrap(); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); let prev = std::env::var("AWSENC_PROFILE").ok(); std::env::remove_var("AWSENC_PROFILE"); @@ -458,10 +452,14 @@ mod tests { if let Some(v) = prev { std::env::set_var("AWSENC_PROFILE", v); } + restore_test_home(prev_home, prev_userprofile, prev_xdg); } #[test] fn resolve_biometric_from_exec_empty_profile() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let tmp = tempfile::tempdir().unwrap(); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); let prev = std::env::var("AWSENC_PROFILE").ok(); std::env::remove_var("AWSENC_PROFILE"); @@ -476,12 +474,14 @@ mod tests { if let Some(v) = prev { std::env::set_var("AWSENC_PROFILE", v); } + restore_test_home(prev_home, prev_userprofile, prev_xdg); } #[test] fn resolve_biometric_from_exec_with_profile() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let tmp = tempfile::tempdir().unwrap(); - std::env::set_var("HOME", tmp.path()); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); let args = cli::ExecArgs { profile_positional: Some("some-profile".to_string()), @@ -490,5 +490,83 @@ mod tests { }; let result = resolve_biometric_from_exec(&args); assert!(!result); + restore_test_home(prev_home, prev_userprofile, prev_xdg); + } + + #[test] + fn resolve_biometric_from_exec_uses_resolved_alias() { + let _lock = TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let tmp = tempfile::tempdir().unwrap(); + let (prev_home, prev_userprofile, prev_xdg) = set_test_home(tmp.path()); + + let mut global = config::GlobalConfig::default(); + global.okta.user = Some("tester@example.com".into()); + global.aliases.insert("prod".into(), "real-profile".into()); + let global_toml = toml::to_string_pretty(&global).unwrap(); + let config_dir = tmp.path().join(".config").join("awsenc"); + std::fs::create_dir_all(&config_dir).unwrap(); + std::fs::write(config_dir.join("config.toml"), global_toml).unwrap(); + + let profile = config::ProfileConfig { + okta: config::ProfileOktaConfig { + organization: Some("org.okta.com".into()), + application: Some("https://org.okta.com/app".into()), + role: Some("arn:aws:iam::123:role/R".into()), + factor: None, + duration: None, + }, + region: None, + secondary_role: None, + }; + config::save_profile_config("real-profile", &profile).unwrap(); + + let prev_bio = std::env::var("AWSENC_BIOMETRIC").ok(); + let prev_user = std::env::var("AWSENC_OKTA_USER").ok(); + std::env::set_var("AWSENC_BIOMETRIC", "true"); + std::env::set_var("AWSENC_OKTA_USER", "tester@example.com"); + let args = cli::ExecArgs { + profile_positional: Some("prod".to_string()), + profile_flag: None, + command: vec!["echo".to_string()], + }; + assert!(resolve_biometric_from_exec(&args)); + match prev_bio { + Some(v) => std::env::set_var("AWSENC_BIOMETRIC", v), + None => std::env::remove_var("AWSENC_BIOMETRIC"), + } + match prev_user { + Some(v) => std::env::set_var("AWSENC_OKTA_USER", v), + None => std::env::remove_var("AWSENC_OKTA_USER"), + } + restore_test_home(prev_home, prev_userprofile, prev_xdg); + } + + fn set_test_home(path: &std::path::Path) -> (Option, Option, Option) { + let prev_home = std::env::var("HOME").ok(); + let prev_userprofile = std::env::var("USERPROFILE").ok(); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); + std::env::set_var("HOME", path); + std::env::set_var("USERPROFILE", path); + std::env::set_var("XDG_CONFIG_HOME", path.join(".config")); + (prev_home, prev_userprofile, prev_xdg) + } + + fn restore_test_home( + prev_home: Option, + prev_userprofile: Option, + prev_xdg: Option, + ) { + match prev_home { + Some(value) => std::env::set_var("HOME", value), + None => std::env::remove_var("HOME"), + } + match prev_userprofile { + Some(value) => std::env::set_var("USERPROFILE", value), + None => std::env::remove_var("USERPROFILE"), + } + match prev_xdg { + Some(value) => std::env::set_var("XDG_CONFIG_HOME", value), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } } } diff --git a/awsenc-cli/src/serve.rs b/awsenc-cli/src/serve.rs index 98dd2d9..7990bdf 100644 --- a/awsenc-cli/src/serve.rs +++ b/awsenc-cli/src/serve.rs @@ -42,18 +42,18 @@ pub async fn run_serve(args: &ServeArgs, storage: &dyn EncryptionStorage) -> Res match state { CredentialState::Fresh => { let creds = decrypt_aws_credentials(storage, &cache.aws_ciphertext)?; - output_credentials(&creds); + print_credentials(&creds)?; usage::record_usage(&profile); } CredentialState::Refresh => { match try_transparent_reauth(&profile, storage, &cache).await { Ok(new_creds) => { - output_credentials(&new_creds); + print_credentials(&new_creds)?; } Err(e) => { tracing::debug!("transparent re-auth failed: {e}; using cached credentials"); let creds = decrypt_aws_credentials(storage, &cache.aws_ciphertext)?; - output_credentials(&creds); + print_credentials(&creds)?; } } usage::record_usage(&profile); @@ -61,7 +61,7 @@ pub async fn run_serve(args: &ServeArgs, storage: &dyn EncryptionStorage) -> Res CredentialState::Expired => { let reauth_result = try_transparent_reauth(&profile, storage, &cache).await; if let Ok(new_creds) = reauth_result { - output_credentials(&new_creds); + print_credentials(&new_creds)?; usage::record_usage(&profile); } else { eprintln!("Credentials for profile '{profile}' are expired"); @@ -74,7 +74,7 @@ pub async fn run_serve(args: &ServeArgs, storage: &dyn EncryptionStorage) -> Res Ok(()) } -fn resolve_serve_profile(args: &ServeArgs) -> Result { +pub(crate) fn resolve_serve_profile(args: &ServeArgs) -> Result { if let Some(ref p) = args.profile { let global = config::load_global_config().unwrap_or_default(); return Ok(config::resolve_alias(p, &global)); @@ -113,11 +113,13 @@ fn decrypt_aws_credentials( } #[allow(clippy::print_stdout)] -fn output_credentials(creds: &AwsCredentials) { +fn print_credentials(creds: &AwsCredentials) -> Result<()> { let output = CredentialProcessOutput::from_credentials(creds); // This is the ONLY thing that goes to stdout - let json = serde_json::to_string(&output).expect("credential JSON serialization failed"); + let json = serde_json::to_string(&output) + .map_err(|e| format!("credential JSON serialization failed: {e}"))?; println!("{json}"); + Ok(()) } /// Attempt transparent re-authentication using a cached Okta session. @@ -151,6 +153,12 @@ async fn try_transparent_reauth( let profile_config = config::load_profile_config(profile)?; let overrides = ConfigOverrides::from_env(); let resolved = config::resolve_config(profile, &global, &profile_config, &overrides)?; + if let Some(role) = resolved.secondary_role.as_deref() { + return Err(format!( + "secondary_role '{role}' is configured but chained role assumption is not supported yet" + ) + .into()); + } let okta = OktaClient::new(&resolved.okta_organization)?; let saml_assertion = okta @@ -212,6 +220,7 @@ mod tests { #[test] fn resolve_serve_profile_no_profile_no_active_no_env() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let prev = std::env::var("AWSENC_PROFILE").ok(); std::env::remove_var("AWSENC_PROFILE"); @@ -234,6 +243,7 @@ mod tests { #[test] fn resolve_serve_profile_active_flag_without_env() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let prev = std::env::var("AWSENC_PROFILE").ok(); std::env::remove_var("AWSENC_PROFILE"); @@ -292,8 +302,8 @@ mod tests { session_token: Zeroizing::new("token".to_string()), expiration: Utc::now(), }; - let output = CredentialProcessOutput::from_credentials(&creds); - let json = serde_json::to_string(&output).unwrap(); + let json = + serde_json::to_string(&CredentialProcessOutput::from_credentials(&creds)).unwrap(); assert!(json.contains("AKIDTEST")); assert!(json.contains("Version")); } diff --git a/awsenc-core/Cargo.toml b/awsenc-core/Cargo.toml index 1652d39..4dcad79 100644 --- a/awsenc-core/Cargo.toml +++ b/awsenc-core/Cargo.toml @@ -8,6 +8,7 @@ rust-version.workspace = true workspace = true [dependencies] +enclaveapp-core = { workspace = true } serde = { version = "1", features = ["derive"] } serde_json = "1" toml = "0.8" @@ -21,6 +22,7 @@ zeroize = { version = "1", features = ["derive"] } url = "2" regex = "1" tokio = { version = "1", features = ["time", "macros"] } +roxmltree = "0.20" [dev-dependencies] tempfile = "3" diff --git a/awsenc-core/src/cache.rs b/awsenc-core/src/cache.rs index 7f073da..1cdc8e5 100644 --- a/awsenc-core/src/cache.rs +++ b/awsenc-core/src/cache.rs @@ -1,5 +1,7 @@ use std::path::PathBuf; +use enclaveapp_core::metadata; + use crate::{Error, Result}; /// Magic bytes: "AWSE" @@ -184,24 +186,14 @@ pub fn read_cache(profile: &str) -> Result> { /// Sets 0o600 permissions on Unix. pub fn write_cache(profile: &str, cache: &CacheFile) -> Result<()> { let path = cache_path(profile)?; - let dir = path - .parent() - .ok_or_else(|| Error::CacheFormat("cache path has no parent directory".into()))?; - let encoded = cache.encode(); - - // Write to a temp file in the same directory, then rename for atomicity. - let sanitized = sanitize_profile_name(profile)?; - let temp_path = dir.join(format!(".{sanitized}.enc.tmp")); - std::fs::write(&temp_path, &encoded)?; - + metadata::atomic_write(&path, &encoded) + .map_err(|e| Error::CacheFormat(format!("failed to write cache: {e}")))?; #[cfg(unix)] { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(&temp_path, std::fs::Permissions::from_mode(0o600))?; + metadata::restrict_file_permissions(&path) + .map_err(|e| Error::CacheFormat(format!("failed to secure cache: {e}")))?; } - - std::fs::rename(&temp_path, &path)?; Ok(()) } @@ -258,25 +250,7 @@ pub fn delete_cache(profile: &str) -> Result<()> { /// Validate a profile name: alphanumeric, hyphens, underscores only, max 64 characters. pub fn sanitize_profile_name(name: &str) -> Result { - if name.is_empty() { - return Err(Error::InvalidProfileName( - "profile name cannot be empty".into(), - )); - } - if name.len() > 64 { - return Err(Error::InvalidProfileName(format!( - "profile name exceeds 64 characters: {name}" - ))); - } - if !name - .chars() - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') - { - return Err(Error::InvalidProfileName(format!( - "profile name contains invalid characters (only alphanumeric, hyphens, underscores allowed): {name}" - ))); - } - Ok(name.to_owned()) + Ok(crate::config::validate_profile_name(name)?.to_owned()) } #[cfg(test)] @@ -414,6 +388,43 @@ mod tests { assert!(!header2.has_okta_session()); } + #[test] + fn write_cache_ignores_preexisting_legacy_tmp_file() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let dir = tempfile::tempdir().unwrap(); + let prev_home = std::env::var("HOME").ok(); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); + std::env::set_var("HOME", dir.path()); + std::env::set_var("XDG_CONFIG_HOME", dir.path().join(".config")); + let path = cache_path("tmp-test").unwrap(); + let temp_path = path.parent().unwrap().join(".tmp-test.enc.tmp"); + std::fs::write(&temp_path, b"stale").unwrap(); + + let cache = CacheFile { + header: CacheHeader { + magic: MAGIC, + version: FORMAT_VERSION, + flags: 0, + credential_expiration: 1, + okta_session_expiration: 0, + }, + aws_ciphertext: vec![1, 2, 3], + okta_session_ciphertext: None, + }; + + write_cache("tmp-test", &cache).unwrap(); + let loaded = read_cache("tmp-test").unwrap().unwrap(); + assert_eq!(loaded.aws_ciphertext, vec![1, 2, 3]); + match prev_home { + Some(v) => std::env::set_var("HOME", v), + None => std::env::remove_var("HOME"), + } + match prev_xdg { + Some(v) => std::env::set_var("XDG_CONFIG_HOME", v), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } + } + #[test] fn encode_large_ciphertext() { let big = vec![0xAB; 100_000]; diff --git a/awsenc-core/src/config.rs b/awsenc-core/src/config.rs index 0ddb7b4..500ca83 100644 --- a/awsenc-core/src/config.rs +++ b/awsenc-core/src/config.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; use std::path::PathBuf; +use enclaveapp_core::metadata; use serde::{Deserialize, Serialize}; +use crate::okta::{validate_okta_application_url, validate_okta_organization}; use crate::{Error, Result}; // --------------------------------------------------------------------------- @@ -46,6 +48,7 @@ pub struct CacheConfig { pub struct ProfileConfig { #[serde(default)] pub okta: ProfileOktaConfig, + pub region: Option, pub secondary_role: Option, } @@ -92,7 +95,10 @@ impl ConfigOverrides { biometric: std::env::var("AWSENC_BIOMETRIC") .ok() .and_then(|v| v.parse::().ok()), - region: None, + region: std::env::var("AWSENC_REGION") + .ok() + .or_else(|| std::env::var("AWS_REGION").ok()) + .or_else(|| std::env::var("AWS_DEFAULT_REGION").ok()), } } } @@ -119,12 +125,25 @@ pub struct ResolvedConfig { // Directory helpers // --------------------------------------------------------------------------- +fn config_root_dir() -> Result { + if let Some(dir) = std::env::var_os("XDG_CONFIG_HOME") { + let path = PathBuf::from(dir); + if !path.is_absolute() { + return Err(Error::Config( + "XDG_CONFIG_HOME must be an absolute path".into(), + )); + } + return Ok(path); + } + + Ok(dirs::home_dir() + .ok_or_else(|| Error::Config("could not determine home directory".into()))? + .join(".config")) +} + /// Returns `~/.config/awsenc/`, creating it with 0o700 permissions if necessary. pub fn config_dir() -> Result { - let dir = dirs::home_dir() - .ok_or_else(|| Error::Config("could not determine home directory".into()))? - .join(".config") - .join("awsenc"); + let dir = config_root_dir()?.join("awsenc"); ensure_dir(&dir)?; Ok(dir) } @@ -166,7 +185,7 @@ pub fn load_global_config() -> Result { /// Load a profile config from `~/.config/awsenc/profiles/.toml`. pub fn load_profile_config(name: &str) -> Result { - let path = profiles_dir()?.join(format!("{name}.toml")); + let path = profile_config_path(name)?; if !path.exists() { return Err(Error::Config(format!("profile config not found: {name}"))); } @@ -177,14 +196,51 @@ pub fn load_profile_config(name: &str) -> Result { /// Save a profile config to `~/.config/awsenc/profiles/.toml`. pub fn save_profile_config(name: &str, config: &ProfileConfig) -> Result<()> { - let dir = profiles_dir()?; - let path = dir.join(format!("{name}.toml")); + let path = profile_config_path(name)?; let contents = toml::to_string_pretty(config)?; - std::fs::write(&path, contents)?; + write_private_file(&path, contents.as_bytes())?; + Ok(()) +} + +/// Validate a profile name before it is used as a filesystem path component. +pub fn validate_profile_name(name: &str) -> Result<&str> { + if name.is_empty() { + return Err(Error::InvalidProfileName( + "profile name cannot be empty".into(), + )); + } + if name.len() > 64 { + return Err(Error::InvalidProfileName(format!( + "profile name exceeds 64 characters: {name}" + ))); + } + if !name + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') + { + return Err(Error::InvalidProfileName(format!( + "profile name contains invalid characters (only alphanumeric, hyphens, underscores allowed): {name}" + ))); + } + Ok(name) +} + +/// Resolve the on-disk path for a profile config after validating the name. +pub fn profile_config_path(name: &str) -> Result { + let name = validate_profile_name(name)?; + Ok(profiles_dir()?.join(format!("{name}.toml"))) +} + +fn write_private_file(path: &std::path::Path, contents: &[u8]) -> Result<()> { + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent)?; + } + metadata::atomic_write(path, contents) + .map_err(|e| Error::Config(format!("failed to write {}: {e}", path.display())))?; #[cfg(unix)] { - use std::os::unix::fs::PermissionsExt; - std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?; + metadata::restrict_file_permissions(path) + .map_err(|e| Error::Config(format!("failed to secure {}: {e}", path.display())))?; } Ok(()) } @@ -210,6 +266,7 @@ pub fn resolve_config( .or_else(|| profile.okta.organization.clone()) .or_else(|| global.okta.organization.clone()) .ok_or_else(|| Error::MissingConfig("okta organization".into()))?; + let okta_organization = validate_okta_organization(&okta_organization)?; let okta_user = overrides .user @@ -222,6 +279,7 @@ pub fn resolve_config( .clone() .or_else(|| profile.okta.application.clone()) .ok_or_else(|| Error::MissingConfig("okta application URL".into()))?; + let okta_application = validate_okta_application_url(&okta_organization, &okta_application)?; let okta_role = overrides .role @@ -250,7 +308,7 @@ pub fn resolve_config( .as_ref() .map(|sr| sr.role_arn.clone()); - let region = overrides.region.clone(); + let region = overrides.region.clone().or_else(|| profile.region.clone()); Ok(ResolvedConfig { okta_organization, @@ -283,6 +341,7 @@ mod tests { #[test] fn config_overrides_from_env_empty() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); // Clear relevant env vars to test defaults std::env::remove_var("AWSENC_OKTA_USER"); std::env::remove_var("AWSENC_OKTA_ORG"); @@ -333,11 +392,12 @@ mod tests { let profile = ProfileConfig { okta: ProfileOktaConfig { organization: None, - application: Some("https://org.okta.com/home/amazon_aws/0oa123/272".into()), + application: Some("https://global-org.okta.com/home/amazon_aws/0oa123/272".into()), role: Some("arn:aws:iam::123456789012:role/MyRole".into()), factor: Some("yubikey".into()), duration: Some(7200), }, + region: Some("us-west-2".into()), secondary_role: None, }; @@ -348,6 +408,7 @@ mod tests { assert_eq!(resolved.okta_user, "globaluser"); assert_eq!(resolved.okta_factor, "yubikey"); // profile overrides global assert_eq!(resolved.okta_duration, 7200); + assert_eq!(resolved.region.as_deref(), Some("us-west-2")); assert!(resolved.biometric); assert_eq!(resolved.refresh_window_seconds, 300); } @@ -366,23 +427,26 @@ mod tests { let profile = ProfileConfig { okta: ProfileOktaConfig { organization: None, - application: Some("https://org.okta.com/app".into()), + application: Some("https://global-org.okta.com/app".into()), role: Some("arn:aws:iam::123:role/R".into()), factor: Some("yubikey".into()), duration: None, }, + region: None, secondary_role: None, }; let overrides = ConfigOverrides { factor: Some("totp".into()), duration: Some(900), + region: Some("eu-west-1".into()), ..Default::default() }; let resolved = resolve_config("test", &global, &profile, &overrides).unwrap(); assert_eq!(resolved.okta_factor, "totp"); // override beats profile assert_eq!(resolved.okta_duration, 900); + assert_eq!(resolved.region.as_deref(), Some("eu-west-1")); } #[test] @@ -395,6 +459,86 @@ mod tests { assert!(result.is_err()); } + #[test] + fn resolve_config_rejects_cross_origin_okta_application() { + let global = GlobalConfig { + okta: OktaConfig { + organization: Some("global-org.okta.com".into()), + user: Some("globaluser".into()), + default_factor: None, + }, + ..Default::default() + }; + let profile = ProfileConfig { + okta: ProfileOktaConfig { + organization: None, + application: Some("https://evil.example.com/home/app".into()), + role: Some("arn:aws:iam::123:role/R".into()), + factor: None, + duration: None, + }, + ..Default::default() + }; + + let error = + resolve_config("test", &global, &profile, &ConfigOverrides::default()).unwrap_err(); + assert!(error + .to_string() + .contains("must match Okta organization origin")); + } + + #[test] + fn resolve_config_rejects_cleartext_okta_application() { + let global = GlobalConfig { + okta: OktaConfig { + organization: Some("global-org.okta.com".into()), + user: Some("globaluser".into()), + default_factor: None, + }, + ..Default::default() + }; + let profile = ProfileConfig { + okta: ProfileOktaConfig { + organization: None, + application: Some("http://global-org.okta.com/home/app".into()), + role: Some("arn:aws:iam::123:role/R".into()), + factor: None, + duration: None, + }, + ..Default::default() + }; + + let error = + resolve_config("test", &global, &profile, &ConfigOverrides::default()).unwrap_err(); + assert!(error.to_string().contains("must use HTTPS")); + } + + #[test] + fn resolve_config_rejects_non_host_okta_organization() { + let global = GlobalConfig { + okta: OktaConfig { + organization: Some("global-org.okta.com/path".into()), + user: Some("globaluser".into()), + default_factor: None, + }, + ..Default::default() + }; + let profile = ProfileConfig { + okta: ProfileOktaConfig { + organization: None, + application: Some("https://global-org.okta.com/home/app".into()), + role: Some("arn:aws:iam::123:role/R".into()), + factor: None, + duration: None, + }, + ..Default::default() + }; + + let error = + resolve_config("test", &global, &profile, &ConfigOverrides::default()).unwrap_err(); + assert!(error.to_string().contains("bare host")); + } + #[test] fn global_config_roundtrip_toml() { let mut config = GlobalConfig::default(); @@ -421,6 +565,7 @@ mod tests { factor: None, duration: Some(3600), }, + region: Some("us-east-1".into()), secondary_role: Some(SecondaryRoleConfig { role_arn: "arn:aws:iam::456:role/S".into(), }), @@ -432,6 +577,7 @@ mod tests { parsed.okta.application.as_deref(), Some("https://org.okta.com/app") ); + assert_eq!(parsed.region.as_deref(), Some("us-east-1")); assert_eq!( parsed .secondary_role @@ -443,14 +589,67 @@ mod tests { #[test] fn config_dir_returns_path() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); // Just verify it produces a path without error let dir = config_dir().unwrap(); assert!(dir.ends_with("awsenc")); } + #[test] + fn config_dir_uses_xdg_config_home_when_set() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let tmp = tempfile::tempdir().unwrap(); + let xdg_dir = tmp.path().join("xdg"); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); + std::env::set_var("XDG_CONFIG_HOME", &xdg_dir); + + let dir = config_dir().unwrap(); + assert_eq!(dir, xdg_dir.join("awsenc")); + + match prev_xdg { + Some(value) => std::env::set_var("XDG_CONFIG_HOME", value), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } + } + + #[test] + fn config_dir_rejects_relative_xdg_config_home() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); + std::env::set_var("XDG_CONFIG_HOME", "relative-config"); + + let error = config_dir().unwrap_err().to_string(); + assert!(error.contains("XDG_CONFIG_HOME")); + + match prev_xdg { + Some(value) => std::env::set_var("XDG_CONFIG_HOME", value), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } + } + #[test] fn profiles_dir_returns_path() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let dir = profiles_dir().unwrap(); assert!(dir.ends_with("profiles")); } + + #[test] + fn validate_profile_name_rejects_path_traversal() { + assert!(validate_profile_name("../escape").is_err()); + } + + #[test] + fn profile_config_path_rejects_invalid_names() { + assert!(profile_config_path("../escape").is_err()); + } + + #[test] + fn config_overrides_from_env_reads_region() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + std::env::set_var("AWSENC_REGION", "ap-southeast-2"); + let overrides = ConfigOverrides::from_env(); + std::env::remove_var("AWSENC_REGION"); + assert_eq!(overrides.region.as_deref(), Some("ap-southeast-2")); + } } diff --git a/awsenc-core/src/lib.rs b/awsenc-core/src/lib.rs index 3d64f95..1afe3db 100644 --- a/awsenc-core/src/lib.rs +++ b/awsenc-core/src/lib.rs @@ -8,6 +8,9 @@ pub mod sts; use thiserror::Error; +#[cfg(test)] +pub(crate) static TEST_ENV_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + #[derive(Debug, Error)] pub enum Error { #[error("IO error: {0}")] diff --git a/awsenc-core/src/okta.rs b/awsenc-core/src/okta.rs index acb0fcb..c8d2ba8 100644 --- a/awsenc-core/src/okta.rs +++ b/awsenc-core/src/okta.rs @@ -2,7 +2,9 @@ use chrono::{DateTime, Utc}; use regex::Regex; use reqwest::header::{ACCEPT, CONTENT_TYPE, COOKIE}; use serde::{Deserialize, Serialize}; +use std::time::Duration; use tracing::{debug, warn}; +use url::Url; use zeroize::Zeroizing; use crate::mfa::MfaChallenge; @@ -22,6 +24,9 @@ pub struct OktaClient { base_url: String, } +const OKTA_HTTP_TIMEOUT: Duration = Duration::from_secs(30); +const MAX_SAML_REDIRECTS: usize = 10; + /// Response states from Okta's `/api/v1/authn` endpoint. #[derive(Debug)] pub enum AuthnResponse { @@ -102,29 +107,115 @@ struct VerifyPushRequest { state_token: String, } +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct CreateSessionRequest { + session_token: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct CreateSessionResponse { + id: String, + expires_at: DateTime, +} + +#[derive(Debug)] +struct HtmlFetchResult { + final_url: Url, + html: String, + redirects_followed: usize, +} + impl OktaClient { /// Create a new Okta API client for the given organization domain. /// /// `organization` should be the full Okta domain, e.g. `mycompany.okta.com`. pub fn new(organization: &str) -> Result { - let client = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build()?; + let organization = validate_okta_organization(organization)?; let base_url = format!("https://{organization}"); - Ok(Self { client, base_url }) + Self::with_base_url_and_timeout(&base_url, OKTA_HTTP_TIMEOUT) } /// Create an Okta client pointing at a custom base URL (for testing). pub fn with_base_url(base_url: &str) -> Result { - let client = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build()?; + Self::with_base_url_and_timeout(base_url, OKTA_HTTP_TIMEOUT) + } + + fn with_base_url_and_timeout(base_url: &str, timeout: Duration) -> Result { + let client = build_okta_http_client(timeout)?; Ok(Self { client, base_url: base_url.to_owned(), }) } + fn validated_app_url(&self, app_url: &str) -> Result { + validate_okta_app_url_against_base_url(&self.base_url, app_url) + } + + fn saml_url_with_session_token(&self, app_url: &str, session_token: &str) -> Result { + let mut url = self.validated_app_url(app_url)?; + url.query_pairs_mut() + .append_pair("sessionToken", session_token); + Ok(url) + } + + fn base_origin(&self) -> Result { + Ok(Url::parse(&self.base_url)?.origin()) + } + + async fn fetch_html_following_redirects( + &self, + mut url: Url, + session_id: Option<&str>, + ) -> Result { + let okta_origin = self.base_origin()?; + + for redirects_followed in 0..=MAX_SAML_REDIRECTS { + let mut request = self.client.get(url.clone()).header(ACCEPT, "text/html"); + if let Some(session_id) = session_id.filter(|_| url.origin() == okta_origin) { + request = request.header(COOKIE, format!("sid={session_id}")); + } + + let resp = request.send().await?; + let status = resp.status(); + + if status.is_redirection() { + if redirects_followed == MAX_SAML_REDIRECTS { + return Err(Error::Saml( + "too many redirects while fetching SAML assertion".into(), + )); + } + + let location = resp + .headers() + .get("location") + .ok_or_else(|| Error::Saml("missing redirect location header".into()))? + .to_str() + .map_err(|_| Error::Saml("invalid redirect location header".into()))?; + url = url.join(location)?; + debug!("following SAML redirect to {url}"); + continue; + } + + if !status.is_success() { + return Err(Error::Saml(format!( + "failed to get SAML assertion (HTTP {status})" + ))); + } + + let html = resp.text().await?; + return Ok(HtmlFetchResult { + final_url: url, + html, + redirects_followed, + }); + } + + unreachable!("redirect loop exits within MAX_SAML_REDIRECTS bounds"); + } + /// Authenticate a user with username and password. /// /// Returns `Success` with a session token, or `MfaRequired` with available factors. @@ -242,7 +333,7 @@ impl OktaClient { &self, factor_id: &str, state_token: &Zeroizing, - timeout: std::time::Duration, + timeout: Duration, ) -> Result { let start = std::time::Instant::now(); @@ -257,7 +348,7 @@ impl OktaClient { return Err(Error::Timeout("push notification timed out".into())); } debug!("push verification waiting, polling again in 2s"); - tokio::time::sleep(std::time::Duration::from_secs(2)).await; + tokio::time::sleep(Duration::from_secs(2)).await; } "REJECTED" => { return Err(Error::Mfa("push notification was rejected".into())); @@ -302,69 +393,67 @@ impl OktaClient { session_token: &Zeroizing, app_url: &str, ) -> Result { - let url = format!("{app_url}?sessionToken={}", session_token.as_str()); + let url = self.saml_url_with_session_token(app_url, session_token.as_str())?; + let response = self.fetch_html_following_redirects(url, None).await?; + extract_saml_assertion(&response.html) + } + + /// Exchange a one-time session token for a reusable Okta session cookie id. + pub async fn create_session(&self, session_token: &Zeroizing) -> Result { + let url = format!("{}/api/v1/sessions", self.base_url); + let body = CreateSessionRequest { + session_token: session_token.as_str().to_owned(), + }; let resp = self .client - .get(&url) - .header(ACCEPT, "text/html") + .post(&url) + .header(ACCEPT, "application/json") + .header(CONTENT_TYPE, "application/json") + .json(&body) .send() .await?; - // Follow the redirect chain manually if needed let status = resp.status(); - if status.is_redirection() { - // Okta might redirect; follow it - if let Some(location) = resp.headers().get("location") { - let redirect_url = location - .to_str() - .map_err(|_| Error::Saml("invalid redirect location header".into()))?; - debug!("following SAML redirect to {redirect_url}"); - let resp2 = self.client.get(redirect_url).send().await?; - let html = resp2.text().await?; - return extract_saml_assertion(&html); - } - } - + let text = resp.text().await?; if !status.is_success() { - return Err(Error::Saml(format!( - "failed to get SAML assertion (HTTP {status})" + let err_msg = parse_okta_error(&text); + return Err(Error::Auth(format!( + "Okta session creation failed (HTTP {status}): {err_msg}" ))); } - let html = resp.text().await?; - extract_saml_assertion(&html) + let created: CreateSessionResponse = serde_json::from_str(&text) + .map_err(|e| Error::Auth(format!("bad Okta session response: {e}")))?; + + Ok(OktaSession { + session_id: created.id, + expiration: created.expires_at, + }) } /// Get a SAML assertion using an existing Okta session cookie. /// /// Used when the Okta session is cached (avoids re-authentication). pub async fn get_saml_with_session(&self, session_id: &str, app_url: &str) -> Result { - let resp = self - .client - .get(app_url) - .header(ACCEPT, "text/html") - .header(COOKIE, format!("sid={session_id}")) - .send() + let app_url = self.validated_app_url(app_url)?; + let response = self + .fetch_html_following_redirects(app_url, Some(session_id)) .await?; - let status = resp.status(); - if status.is_redirection() { - // Session may be expired; Okta redirects to login - warn!("Okta session redirect (likely expired)"); - return Err(Error::Auth( - "Okta session expired (received redirect)".into(), - )); - } - - if !status.is_success() { - return Err(Error::Saml(format!( - "failed to get SAML assertion with session (HTTP {status})" - ))); + match extract_saml_assertion(&response.html) { + Ok(assertion) => Ok(assertion), + Err(_extract_error) if response.redirects_followed > 0 => { + warn!( + "Okta session redirect ended without a SAML assertion at {}", + response.final_url + ); + Err(Error::Auth( + "Okta session expired or redirected to a non-SAML page".into(), + )) + } + Err(error) => Err(error), } - - let html = resp.text().await?; - extract_saml_assertion(&html) } } @@ -372,6 +461,71 @@ impl OktaClient { // Internal helpers // --------------------------------------------------------------------------- +fn build_okta_http_client(timeout: Duration) -> Result { + Ok(reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .timeout(timeout) + .build()?) +} + +pub(crate) fn validate_okta_organization(organization: &str) -> Result { + let organization = organization.trim(); + if organization.is_empty() { + return Err(Error::Config("okta organization cannot be empty".into())); + } + if organization.contains("://") + || organization.contains('/') + || organization.contains('?') + || organization.contains('#') + || organization.contains('@') + { + return Err(Error::Config(format!( + "okta organization must be a bare host or host:port: {organization}" + ))); + } + + let parsed = Url::parse(&format!("https://{organization}"))?; + if parsed.host_str().is_none() + || !parsed.username().is_empty() + || parsed.password().is_some() + || parsed.path() != "/" + || parsed.query().is_some() + || parsed.fragment().is_some() + { + return Err(Error::Config(format!( + "okta organization must be a bare host or host:port: {organization}" + ))); + } + + Ok(organization.to_owned()) +} + +pub(crate) fn validate_okta_application_url(organization: &str, app_url: &str) -> Result { + let organization = validate_okta_organization(organization)?; + let base_url = format!("https://{organization}"); + Ok(validate_okta_app_url_against_base_url(&base_url, app_url)?.to_string()) +} + +fn validate_okta_app_url_against_base_url(base_url: &str, app_url: &str) -> Result { + let base = Url::parse(base_url)?; + let app = Url::parse(app_url)?; + + if base.scheme() == "https" && app.scheme() != "https" { + return Err(Error::Config(format!( + "okta application URL must use HTTPS: {app_url}" + ))); + } + + if base.origin() != app.origin() { + return Err(Error::Config(format!( + "okta application URL must match Okta organization origin {}: {app_url}", + base.origin().ascii_serialization() + ))); + } + + Ok(app) +} + fn parse_authn_response(text: &str) -> Result { let raw: OktaAuthnRaw = serde_json::from_str(text).map_err(|e| Error::Auth(format!("bad Okta response: {e}")))?; @@ -458,6 +612,8 @@ mod tests { #![allow(clippy::unwrap_used, clippy::panic)] use super::*; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[test] fn parse_success_response() { @@ -578,6 +734,108 @@ mod tests { assert_eq!(client.base_url, "https://mycompany.okta.com"); } + #[test] + fn validate_okta_organization_accepts_host() { + let organization = validate_okta_organization("mycompany.okta.com").unwrap(); + assert_eq!(organization, "mycompany.okta.com"); + } + + #[test] + fn validate_okta_organization_rejects_path() { + let error = validate_okta_organization("mycompany.okta.com/home/app").unwrap_err(); + assert!(error.to_string().contains("bare host")); + } + + #[test] + fn validate_okta_organization_rejects_userinfo() { + let error = validate_okta_organization("user@mycompany.okta.com").unwrap_err(); + assert!(error.to_string().contains("bare host")); + } + + #[test] + fn validate_okta_application_url_accepts_same_origin_https() { + let url = validate_okta_application_url( + "mycompany.okta.com", + "https://mycompany.okta.com/home/amazon_aws/0oa123/272", + ) + .unwrap(); + + assert_eq!(url, "https://mycompany.okta.com/home/amazon_aws/0oa123/272"); + } + + #[test] + fn validate_okta_application_url_rejects_cross_origin() { + let error = validate_okta_application_url( + "mycompany.okta.com", + "https://evil.example.com/home/amazon_aws/0oa123/272", + ) + .unwrap_err(); + + assert!(error + .to_string() + .contains("must match Okta organization origin")); + } + + #[test] + fn validate_okta_application_url_rejects_cleartext() { + let error = validate_okta_application_url( + "mycompany.okta.com", + "http://mycompany.okta.com/home/amazon_aws/0oa123/272", + ) + .unwrap_err(); + + assert!(error.to_string().contains("must use HTTPS")); + } + + #[test] + fn saml_url_with_session_token_preserves_existing_query() { + let client = OktaClient::new("mycompany.okta.com").unwrap(); + let url = client + .saml_url_with_session_token( + "https://mycompany.okta.com/home/amazon_aws/0oa123/272?fromHome=true", + "session-token-123", + ) + .unwrap(); + + assert_eq!( + url.query(), + Some("fromHome=true&sessionToken=session-token-123") + ); + } + + #[tokio::test] + async fn okta_client_timeout_is_bounded() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/v1/authn")) + .respond_with( + ResponseTemplate::new(200) + .set_delay(Duration::from_millis(200)) + .set_body_json(serde_json::json!({ + "status": "SUCCESS", + "sessionToken": "slow-token" + })), + ) + .expect(1) + .mount(&server) + .await; + + let client = + OktaClient::with_base_url_and_timeout(&server.uri(), Duration::from_millis(50)) + .unwrap(); + let password = Zeroizing::new("password".to_string()); + let error = client + .authenticate("user@example.com", &password) + .await + .unwrap_err(); + + assert!( + matches!(error, Error::Http(_) | Error::Timeout(_)), + "unexpected error: {error}" + ); + } + #[test] fn parse_unknown_status() { let json = r#"{"status": "LOCKED_OUT"}"#; diff --git a/awsenc-core/src/profile.rs b/awsenc-core/src/profile.rs index 7c363c9..e57e1e0 100644 --- a/awsenc-core/src/profile.rs +++ b/awsenc-core/src/profile.rs @@ -87,16 +87,15 @@ pub fn list_profiles() -> Result> { /// Check if a profile config file exists. pub fn profile_exists(name: &str) -> bool { - config::profiles_dir() - .map(|dir| dir.join(format!("{name}.toml")).exists()) + config::profile_config_path(name) + .map(|path| path.exists()) .unwrap_or(false) } /// Delete a profile's config and cache files. pub fn delete_profile(name: &str) -> Result<()> { // Delete profile config - let profiles_dir = config::profiles_dir()?; - let config_path = profiles_dir.join(format!("{name}.toml")); + let config_path = config::profile_config_path(name)?; if config_path.exists() { std::fs::remove_file(&config_path)?; } else { @@ -118,11 +117,13 @@ mod tests { #[test] fn profile_exists_returns_false_for_missing() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); assert!(!profile_exists("nonexistent-profile-xyz-12345")); } #[test] fn list_profiles_returns_empty_for_new_install() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); // This test relies on the profiles directory existing but potentially // being empty. It should at minimum not error. let result = list_profiles(); @@ -148,6 +149,13 @@ mod tests { #[test] fn roundtrip_profile_config_and_check_exists() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); + let tmp = tempfile::tempdir().unwrap(); + let prev_home = std::env::var("HOME").ok(); + let prev_xdg = std::env::var("XDG_CONFIG_HOME").ok(); + std::env::set_var("HOME", tmp.path()); + std::env::set_var("XDG_CONFIG_HOME", tmp.path().join(".config")); + // Create a temp profile, verify it exists, then clean up let name = "test-roundtrip-profile-awsenc"; let config = ProfileConfig { @@ -158,6 +166,7 @@ mod tests { factor: None, duration: None, }, + region: Some("us-west-1".into()), secondary_role: None, }; @@ -172,11 +181,26 @@ mod tests { // Delete delete_profile(name).unwrap(); assert!(!profile_exists(name)); + match prev_home { + Some(v) => std::env::set_var("HOME", v), + None => std::env::remove_var("HOME"), + } + match prev_xdg { + Some(v) => std::env::set_var("XDG_CONFIG_HOME", v), + None => std::env::remove_var("XDG_CONFIG_HOME"), + } } #[test] fn delete_nonexistent_profile_errors() { + let _lock = crate::TEST_ENV_MUTEX.lock().expect("mutex poisoned"); let result = delete_profile("definitely-does-not-exist-xyz"); assert!(result.is_err()); } + + #[test] + fn invalid_profile_name_is_rejected_consistently() { + assert!(!profile_exists("../escape")); + assert!(delete_profile("../escape").is_err()); + } } diff --git a/awsenc-core/src/sts.rs b/awsenc-core/src/sts.rs index cdd356e..eaeefa4 100644 --- a/awsenc-core/src/sts.rs +++ b/awsenc-core/src/sts.rs @@ -1,6 +1,7 @@ use base64::Engine; use chrono::{DateTime, Utc}; -use regex::Regex; +use roxmltree::{Document, Node}; +use std::time::Duration; use tracing::debug; use zeroize::Zeroizing; @@ -21,19 +22,22 @@ pub struct StsClient { endpoint_url: String, } +const STS_HTTP_TIMEOUT: Duration = Duration::from_secs(30); + impl StsClient { /// Create a new STS client with the default endpoint. pub fn new() -> Self { - Self { - client: reqwest::Client::new(), - endpoint_url: "https://sts.amazonaws.com".to_owned(), - } + Self::with_endpoint_and_timeout("https://sts.amazonaws.com", STS_HTTP_TIMEOUT) } /// Create a new STS client with a custom endpoint URL (for testing). pub fn with_endpoint(endpoint_url: &str) -> Self { + Self::with_endpoint_and_timeout(endpoint_url, STS_HTTP_TIMEOUT) + } + + fn with_endpoint_and_timeout(endpoint_url: &str, timeout: Duration) -> Self { Self { - client: reqwest::Client::new(), + client: build_sts_http_client(timeout), endpoint_url: endpoint_url.to_owned(), } } @@ -68,8 +72,8 @@ impl StsClient { let body = resp.text().await?; if !status.is_success() { - let error_msg = extract_xml_tag(&body, "Message") - .or_else(|| extract_xml_tag(&body, "Error")) + let error_msg = find_text_by_local_name(&body, "Message") + .or_else(|| find_text_by_local_name(&body, "Error")) .unwrap_or_else(|| body.chars().take(500).collect()); return Err(Error::Sts(format!( "AssumeRoleWithSAML failed (HTTP {status}): {error_msg}" @@ -86,18 +90,25 @@ impl Default for StsClient { } } +fn build_sts_http_client(timeout: Duration) -> reqwest::Client { + reqwest::Client::builder() + .timeout(timeout) + .build() + .expect("failed to build STS HTTP client") +} + /// Parse the `AssumeRoleWithSAML` XML response into `AwsCredentials`. fn parse_assume_role_response(xml: &str) -> Result { - let access_key_id = extract_xml_tag(xml, "AccessKeyId") + let access_key_id = find_text_by_local_name(xml, "AccessKeyId") .ok_or_else(|| Error::Sts("missing AccessKeyId in STS response".into()))?; - let secret_access_key = extract_xml_tag(xml, "SecretAccessKey") + let secret_access_key = find_text_by_local_name(xml, "SecretAccessKey") .ok_or_else(|| Error::Sts("missing SecretAccessKey in STS response".into()))?; - let session_token = extract_xml_tag(xml, "SessionToken") + let session_token = find_text_by_local_name(xml, "SessionToken") .ok_or_else(|| Error::Sts("missing SessionToken in STS response".into()))?; - let expiration_str = extract_xml_tag(xml, "Expiration") + let expiration_str = find_text_by_local_name(xml, "Expiration") .ok_or_else(|| Error::Sts("missing Expiration in STS response".into()))?; let expiration: DateTime = expiration_str @@ -112,14 +123,12 @@ fn parse_assume_role_response(xml: &str) -> Result { }) } -/// Extract the text content of an XML tag using regex. -/// -/// This is intentionally simple -- the STS response format is predictable and -/// adding a full XML parser dependency is unnecessary. -fn extract_xml_tag(xml: &str, tag: &str) -> Option { - let pattern = format!("<{tag}>([^<]*)"); - let re = Regex::new(&pattern).ok()?; - re.captures(xml).map(|caps| caps[1].to_owned()) +fn find_text_by_local_name(xml: &str, tag: &str) -> Option { + let doc = Document::parse(xml).ok()?; + doc.descendants() + .find(|node| node.is_element() && node.tag_name().name() == tag) + .and_then(text_content) + .map(str::to_owned) } /// Parse available roles from a base64-encoded SAML assertion. @@ -132,48 +141,51 @@ pub fn parse_saml_roles(saml_assertion: &str) -> Result> { let decoded_bytes = base64::engine::general_purpose::STANDARD.decode(saml_assertion)?; let decoded = String::from_utf8(decoded_bytes) .map_err(|e| Error::Saml(format!("SAML assertion is not valid UTF-8: {e}")))?; - - // Extract the Role attribute values from the SAML XML. - // The attribute looks like: - // - // arn:aws:iam::123:role/Role,arn:aws:iam::123:saml-provider/Okta - // - let re = Regex::new(r"<(?:\w+:)?AttributeValue[^>]*>([^<]*)")?; - - // Find the Role attribute block - let role_attr_re = Regex::new( - r#"(?s)<(?:\w+:)?Attribute[^>]+Name\s*=\s*"https://aws\.amazon\.com/SAML/Attributes/Role"[^>]*>(.*?)"#, - )?; - - let role_block = role_attr_re - .captures(&decoded) - .ok_or_else(|| Error::Saml("no Role attribute found in SAML assertion".into()))?; - - let block_text = &role_block[1]; + let doc = + Document::parse(&decoded).map_err(|e| Error::Saml(format!("invalid SAML XML: {e}")))?; let mut roles = Vec::new(); - - for caps in re.captures_iter(block_text) { - let value = caps[1].trim(); - if value.is_empty() { - continue; + let mut found_role_attribute = false; + + for attribute in doc.descendants().filter(|node| { + node.is_element() + && node.tag_name().name() == "Attribute" + && node.attribute("Name") == Some("https://aws.amazon.com/SAML/Attributes/Role") + }) { + found_role_attribute = true; + for value_node in attribute + .children() + .filter(|node| node.is_element() && node.tag_name().name() == "AttributeValue") + { + let Some(value) = text_content(value_node).map(str::trim) else { + continue; + }; + if value.is_empty() { + continue; + } + + let parts: Vec<&str> = value.split(',').collect(); + if parts.len() != 2 { + debug!("skipping malformed role attribute value: {value}"); + continue; + } + + let (role_arn, principal_arn) = if parts[0].contains(":role/") { + (parts[0].trim().to_owned(), parts[1].trim().to_owned()) + } else { + (parts[1].trim().to_owned(), parts[0].trim().to_owned()) + }; + + roles.push(SamlRole { + role_arn, + principal_arn, + }); } + } - let parts: Vec<&str> = value.split(',').collect(); - if parts.len() != 2 { - debug!("skipping malformed role attribute value: {value}"); - continue; - } - - let (role_arn, principal_arn) = if parts[0].contains(":role/") { - (parts[0].trim().to_owned(), parts[1].trim().to_owned()) - } else { - (parts[1].trim().to_owned(), parts[0].trim().to_owned()) - }; - - roles.push(SamlRole { - role_arn, - principal_arn, - }); + if !found_role_attribute { + return Err(Error::Saml( + "no Role attribute found in SAML assertion".into(), + )); } if roles.is_empty() { @@ -183,17 +195,32 @@ pub fn parse_saml_roles(saml_assertion: &str) -> Result> { Ok(roles) } +fn text_content<'input>(node: Node<'input, 'input>) -> Option<&'input str> { + node.text().or_else(|| { + if node + .children() + .all(|child| !child.is_text() && !child.is_element()) + { + Some("") + } else { + None + } + }) +} + #[cfg(test)] mod tests { #![allow(clippy::unwrap_used)] use super::*; + use wiremock::matchers::{body_string_contains, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[test] fn extract_xml_tag_found() { let xml = "AKIAIOSFODNN7EXAMPLE"; assert_eq!( - extract_xml_tag(xml, "AccessKeyId"), + find_text_by_local_name(xml, "AccessKeyId"), Some("AKIAIOSFODNN7EXAMPLE".to_owned()) ); } @@ -201,13 +228,16 @@ mod tests { #[test] fn extract_xml_tag_not_found() { let xml = "value"; - assert_eq!(extract_xml_tag(xml, "AccessKeyId"), None); + assert_eq!(find_text_by_local_name(xml, "AccessKeyId"), None); } #[test] fn extract_xml_tag_empty() { let xml = ""; - assert_eq!(extract_xml_tag(xml, "AccessKeyId"), Some(String::new())); + assert_eq!( + find_text_by_local_name(xml, "AccessKeyId"), + Some(String::new()) + ); } #[test] @@ -311,6 +341,44 @@ mod tests { assert!(parse_saml_roles(&b64).is_err()); } + #[test] + fn parse_assume_role_response_with_namespace() { + let xml = r#" + + + + ASIAXMLNS + secret + token + 2026-04-11T16:30:00Z + + + + "#; + + let creds = parse_assume_role_response(xml).unwrap(); + assert_eq!(creds.access_key_id, "ASIAXMLNS"); + } + + #[test] + fn parse_saml_roles_handles_namespaced_attributes() { + let saml_xml = r#" + + + + + arn:aws:iam::123456789012:saml-provider/Okta,arn:aws:iam::123456789012:role/Admin + + + + + "#; + + let b64 = base64::engine::general_purpose::STANDARD.encode(saml_xml); + let roles = parse_saml_roles(&b64).unwrap(); + assert_eq!(roles[0].role_arn, "arn:aws:iam::123456789012:role/Admin"); + } + #[test] fn parse_saml_roles_bad_base64() { assert!(parse_saml_roles("not-valid-base64!!!").is_err()); @@ -327,4 +395,37 @@ mod tests { let client = StsClient::with_endpoint("http://localhost:4566"); assert_eq!(client.endpoint_url, "http://localhost:4566"); } + + #[tokio::test] + async fn sts_client_timeout_is_bounded() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/")) + .and(body_string_contains("Action=AssumeRoleWithSAML")) + .respond_with( + ResponseTemplate::new(200) + .set_delay(Duration::from_millis(200)) + .set_body_string(""), + ) + .expect(1) + .mount(&server) + .await; + + let client = StsClient::with_endpoint_and_timeout(&server.uri(), Duration::from_millis(50)); + let error = client + .assume_role_with_saml( + "arn:aws:iam::123456789012:role/TestRole", + "arn:aws:iam::123456789012:saml-provider/Okta", + "base64-saml-assertion", + 3600, + ) + .await + .unwrap_err(); + + assert!( + matches!(error, Error::Http(_) | Error::Timeout(_)), + "unexpected error: {error}" + ); + } } diff --git a/awsenc-core/tests/cache_disk_tests.rs b/awsenc-core/tests/cache_disk_tests.rs index e725427..3cf91f6 100644 --- a/awsenc-core/tests/cache_disk_tests.rs +++ b/awsenc-core/tests/cache_disk_tests.rs @@ -295,6 +295,7 @@ fn config_profile_save_load_roundtrip() { factor: Some("yubikey".into()), duration: Some(7200), }, + region: Some("us-west-2".into()), secondary_role: Some(SecondaryRoleConfig { role_arn: "arn:aws:iam::987654321098:role/CrossAccount".into(), }), @@ -320,6 +321,7 @@ fn config_profile_save_load_roundtrip() { ); assert_eq!(loaded.okta.factor.as_deref(), Some("yubikey")); assert_eq!(loaded.okta.duration, Some(7200)); + assert_eq!(loaded.region.as_deref(), Some("us-west-2")); assert_eq!( loaded .secondary_role @@ -342,6 +344,7 @@ fn config_profile_minimal_roundtrip() { factor: None, duration: None, }, + region: None, secondary_role: None, }; diff --git a/awsenc-core/tests/http_mock_tests.rs b/awsenc-core/tests/http_mock_tests.rs index 125c0f6..f78563d 100644 --- a/awsenc-core/tests/http_mock_tests.rs +++ b/awsenc-core/tests/http_mock_tests.rs @@ -3,8 +3,8 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; -use wiremock::matchers::{body_string_contains, header, method, path, path_regex}; -use wiremock::{Mock, MockServer, Request, Respond, ResponseTemplate}; +use wiremock::matchers::{body_string_contains, header, method, path, path_regex, query_param}; +use wiremock::{Match, Mock, MockServer, Request, Respond, ResponseTemplate}; use zeroize::Zeroizing; use awsenc_core::okta::{AuthnResponse, OktaClient}; @@ -33,6 +33,14 @@ impl Respond for SequentialResponder { } } +struct MissingCookieMatcher; + +impl Match for MissingCookieMatcher { + fn matches(&self, request: &Request) -> bool { + !request.headers.contains_key("cookie") + } +} + // =========================================================================== // Okta API tests // =========================================================================== @@ -335,6 +343,7 @@ async fn okta_saml_assertion_extraction() { Mock::given(method("GET")) .and(path("/home/amazon_aws/0oa123abc/272")) + .and(query_param("sessionToken", "session-for-saml")) .respond_with(ResponseTemplate::new(200).set_body_string(saml_html)) .expect(1) .mount(&server) @@ -351,6 +360,51 @@ async fn okta_saml_assertion_extraction() { assert_eq!(result, "PHNhbWw+dGVzdGRhdGE8L3NhbWw+"); } +#[tokio::test] +async fn okta_saml_assertion_follows_multiple_redirects() { + let server = MockServer::start().await; + + let saml_html = r#" + +
+ +
+ +"#; + + Mock::given(method("GET")) + .and(path("/home/amazon_aws/0oa123abc/272")) + .and(query_param("sessionToken", "session-for-saml")) + .respond_with(ResponseTemplate::new(302).insert_header("location", "/app/step-one")) + .expect(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/app/step-one")) + .respond_with(ResponseTemplate::new(302).insert_header("location", "/app/step-two")) + .expect(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/app/step-two")) + .respond_with(ResponseTemplate::new(200).set_body_string(saml_html)) + .expect(1) + .mount(&server) + .await; + + let client = OktaClient::with_base_url(&server.uri()).unwrap(); + let session_token = Zeroizing::new("session-for-saml".to_string()); + let app_url = format!("{}/home/amazon_aws/0oa123abc/272", server.uri()); + let result = client + .get_saml_assertion(&session_token, &app_url) + .await + .unwrap(); + + assert_eq!(result, "bXVsdGktaG9wLXNhbWw="); +} + #[tokio::test] async fn okta_saml_assertion_missing() { let server = MockServer::start().await; @@ -411,16 +465,86 @@ async fn okta_session_based_saml() { assert_eq!(result, "c2Vzc2lvbi1iYXNlZC1zYW1s"); } +#[tokio::test] +async fn okta_session_based_saml_follows_cross_origin_redirect_without_cookie() { + let server = MockServer::start().await; + let federated_server = MockServer::start().await; + + let saml_html = r#" + +
+ +
+ +"#; + + Mock::given(method("GET")) + .and(path("/home/amazon_aws/0oa_session/272")) + .and(header("cookie", "sid=cached-session-id-xyz")) + .respond_with(ResponseTemplate::new(302).insert_header( + "location", + format!("{}/federated/saml", federated_server.uri()), + )) + .expect(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/federated/saml")) + .and(MissingCookieMatcher) + .respond_with(ResponseTemplate::new(200).set_body_string(saml_html)) + .expect(1) + .mount(&federated_server) + .await; + + let client = OktaClient::with_base_url(&server.uri()).unwrap(); + let app_url = format!("{}/home/amazon_aws/0oa_session/272", server.uri()); + let result = client + .get_saml_with_session("cached-session-id-xyz", &app_url) + .await + .unwrap(); + + assert_eq!(result, "Y3Jvc3Mtb3JpZ2luLXNhbWw="); +} + +#[tokio::test] +async fn okta_create_session_returns_cookie_id() { + let server = MockServer::start().await; + + Mock::given(method("POST")) + .and(path("/api/v1/sessions")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "id": "sid-real-session-id", + "expiresAt": "2026-04-11T20:00:00Z" + }))) + .expect(1) + .mount(&server) + .await; + + let client = OktaClient::with_base_url(&server.uri()).unwrap(); + let session_token = Zeroizing::new("one-time-session-token".to_string()); + let session = client.create_session(&session_token).await.unwrap(); + + assert_eq!(session.session_id, "sid-real-session-id"); + assert_eq!(session.expiration.to_rfc3339(), "2026-04-11T20:00:00+00:00"); +} + #[tokio::test] async fn okta_session_expired_redirect() { let server = MockServer::start().await; + let login_html = "

Please sign in

"; - // When the session is expired, Okta returns a redirect. Mock::given(method("GET")) .and(path("/home/amazon_aws/0oa_expired/272")) - .respond_with( - ResponseTemplate::new(302).insert_header("location", "https://login.okta.com"), - ) + .respond_with(ResponseTemplate::new(302).insert_header("location", "/login/signin")) + .expect(1) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/login/signin")) + .and(header("cookie", "sid=expired-session-id")) + .respond_with(ResponseTemplate::new(200).set_body_string(login_html)) .expect(1) .mount(&server) .await; diff --git a/awsenc-tpm-bridge/src/main.rs b/awsenc-tpm-bridge/src/main.rs index 1fef0ab..9519b41 100644 --- a/awsenc-tpm-bridge/src/main.rs +++ b/awsenc-tpm-bridge/src/main.rs @@ -4,17 +4,62 @@ mod tpm; use base64::prelude::*; -use enclaveapp_bridge::{BridgeRequest, BridgeResponse}; +use enclaveapp_bridge::BridgeResponse; +use serde::Deserialize; use std::io::{self, BufRead, Write}; +const DEFAULT_APP_NAME: &str = "awsenc"; +const DEFAULT_KEY_LABEL: &str = "cache-key"; + +#[derive(Debug, Deserialize)] +struct BridgeRequestCompat { + method: String, + #[serde(default)] + params: BridgeParamsCompat, +} + +#[derive(Debug, Default, Deserialize)] +struct BridgeParamsCompat { + #[serde(default)] + data: String, + #[serde(default)] + biometric: bool, + #[serde(default)] + app_name: String, + #[serde(default)] + key_label: String, +} + +impl BridgeParamsCompat { + fn app_name(&self) -> &str { + if self.app_name.is_empty() { + DEFAULT_APP_NAME + } else { + &self.app_name + } + } + + fn key_label(&self) -> &str { + if self.key_label.is_empty() { + DEFAULT_KEY_LABEL + } else { + &self.key_label + } + } +} + fn handle_request( - request: &BridgeRequest, + request: &BridgeRequestCompat, storage: &mut Option, ) -> BridgeResponse { match request.method.as_str() { "init" => { let biometric = request.params.biometric; - match tpm::TpmStorage::new(biometric) { + match tpm::TpmStorage::new( + request.params.app_name(), + request.params.key_label(), + biometric, + ) { Ok(s) => { *storage = Some(s); BridgeResponse::success("ok") @@ -58,9 +103,14 @@ fn handle_request( Err(e) => BridgeResponse::error(&format!("decrypt failed: {e}")), } } - "destroy" => { - *storage = None; - BridgeResponse::success("ok") + "destroy" | "delete" => { + match tpm::TpmStorage::delete(request.params.app_name(), request.params.key_label()) { + Ok(()) => { + *storage = None; + BridgeResponse::success("ok") + } + Err(e) => BridgeResponse::error(&format!("delete failed: {e}")), + } } other => BridgeResponse::error(&format!("unknown method: {other}")), } @@ -87,7 +137,7 @@ fn main() { continue; } - let response = match serde_json::from_str::(&line) { + let response = match serde_json::from_str::(&line) { Ok(req) => handle_request(&req, &mut storage), Err(e) => BridgeResponse::error(&format!("invalid JSON: {e}")), }; @@ -108,15 +158,15 @@ fn main() { #[allow(clippy::unwrap_used, clippy::panic)] mod tests { use super::*; - use enclaveapp_bridge::BridgeParams; - fn make_request(method: &str, data: &str, biometric: bool) -> BridgeRequest { - BridgeRequest { + fn make_request(method: &str, data: &str, biometric: bool) -> BridgeRequestCompat { + BridgeRequestCompat { method: method.to_string(), - params: BridgeParams { + params: BridgeParamsCompat { data: data.to_string(), biometric, - app_name: "awsenc".to_string(), + app_name: DEFAULT_APP_NAME.to_string(), + key_label: DEFAULT_KEY_LABEL.to_string(), }, } } @@ -124,24 +174,28 @@ mod tests { #[test] fn parse_init_request() { let json = r#"{"method": "init", "params": {"biometric": false}}"#; - let req: BridgeRequest = serde_json::from_str(json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); assert_eq!(req.method, "init"); assert!(!req.params.biometric); + assert_eq!(req.params.app_name(), DEFAULT_APP_NAME); + assert_eq!(req.params.key_label(), DEFAULT_KEY_LABEL); } #[test] fn parse_init_request_defaults() { let json = r#"{"method": "init", "params": {}}"#; - let req: BridgeRequest = serde_json::from_str(json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); assert_eq!(req.method, "init"); assert!(!req.params.biometric); assert!(req.params.data.is_empty()); + assert_eq!(req.params.app_name(), DEFAULT_APP_NAME); + assert_eq!(req.params.key_label(), DEFAULT_KEY_LABEL); } #[test] fn parse_encrypt_request() { let json = r#"{"method": "encrypt", "params": {"data": "aGVsbG8=", "biometric": false}}"#; - let req: BridgeRequest = serde_json::from_str(json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); assert_eq!(req.method, "encrypt"); assert_eq!(req.params.data, "aGVsbG8="); } @@ -149,7 +203,7 @@ mod tests { #[test] fn parse_decrypt_request() { let json = r#"{"method": "decrypt", "params": {"data": "Y2lwaGVy"}}"#; - let req: BridgeRequest = serde_json::from_str(json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); assert_eq!(req.method, "decrypt"); assert_eq!(req.params.data, "Y2lwaGVy"); } @@ -157,8 +211,26 @@ mod tests { #[test] fn parse_destroy_request() { let json = r#"{"method": "destroy", "params": {}}"#; - let req: BridgeRequest = serde_json::from_str(json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); assert_eq!(req.method, "destroy"); + assert_eq!(req.params.app_name(), DEFAULT_APP_NAME); + assert_eq!(req.params.key_label(), DEFAULT_KEY_LABEL); + } + + #[test] + fn parse_delete_request() { + let json = r#"{"method": "delete", "params": {"key_label": "cache-key"}}"#; + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); + assert_eq!(req.method, "delete"); + assert_eq!(req.params.key_label(), DEFAULT_KEY_LABEL); + } + + #[test] + fn parse_request_uses_binary_defaults_for_legacy_payloads() { + let json = r#"{"method":"init","params":{"biometric":true}}"#; + let req: BridgeRequestCompat = serde_json::from_str(json).unwrap(); + assert_eq!(req.params.app_name(), DEFAULT_APP_NAME); + assert_eq!(req.params.key_label(), DEFAULT_KEY_LABEL); } #[test] @@ -194,7 +266,22 @@ mod tests { let req = make_request("destroy", "", false); let mut storage = None; let resp = handle_request(&req, &mut storage); - assert!(resp.result.is_some()); + assert!( + resp.result.is_some() || resp.error.is_some(), + "destroy should return a structured response" + ); + assert!(storage.is_none()); + } + + #[test] + fn handle_delete_clears_storage() { + let req = make_request("delete", "", false); + let mut storage = None; + let resp = handle_request(&req, &mut storage); + assert!( + resp.result.is_some() || resp.error.is_some(), + "delete should return a structured response" + ); assert!(storage.is_none()); } @@ -236,7 +323,7 @@ mod tests { let req = make_request("encrypt", "", false); // On platforms without a TPM, new() may fail and storage is None, // so we get "not initialized" instead of "missing data". Both are valid errors. - let mut storage = tpm::TpmStorage::new(false).ok(); + let mut storage = tpm::TpmStorage::new("awsenc", "cache-key", false).ok(); let resp = handle_request(&req, &mut storage); assert!(resp.error.is_some()); } @@ -244,7 +331,7 @@ mod tests { #[test] fn handle_encrypt_invalid_base64() { let req = make_request("encrypt", "not-valid-base64!!!", false); - let mut storage = tpm::TpmStorage::new(false).ok(); + let mut storage = tpm::TpmStorage::new("awsenc", "cache-key", false).ok(); let resp = handle_request(&req, &mut storage); assert!(resp.error.is_some()); } @@ -252,7 +339,7 @@ mod tests { #[test] fn handle_decrypt_missing_data() { let req = make_request("decrypt", "", false); - let mut storage = tpm::TpmStorage::new(false).ok(); + let mut storage = tpm::TpmStorage::new("awsenc", "cache-key", false).ok(); let resp = handle_request(&req, &mut storage); assert!(resp.error.is_some()); } @@ -260,7 +347,7 @@ mod tests { #[cfg(not(target_os = "windows"))] #[test] fn encrypt_returns_platform_error_on_non_windows() { - let storage = tpm::TpmStorage::new(false).unwrap(); + let storage = tpm::TpmStorage::new("awsenc", "cache-key", false).unwrap(); let result = storage.encrypt(b"hello"); assert!(result.is_err()); assert!(result.unwrap_err().contains("only supported on Windows")); @@ -269,7 +356,7 @@ mod tests { #[cfg(not(target_os = "windows"))] #[test] fn decrypt_returns_platform_error_on_non_windows() { - let storage = tpm::TpmStorage::new(false).unwrap(); + let storage = tpm::TpmStorage::new("awsenc", "cache-key", false).unwrap(); let result = storage.decrypt(b"hello"); assert!(result.is_err()); assert!(result.unwrap_err().contains("only supported on Windows")); @@ -278,15 +365,15 @@ mod tests { #[test] fn roundtrip_json_protocol() { // Simulate the full JSON protocol flow - let init_json = r#"{"method":"init","params":{"biometric":false}}"#; - let encrypt_json = - r#"{"method":"encrypt","params":{"data":"aGVsbG8gd29ybGQ=","biometric":false}}"#; - let destroy_json = r#"{"method":"destroy","params":{}}"#; + let init_json = r#"{"method":"init","params":{"app_name":"awsenc","key_label":"cache-key","biometric":false}}"#; + let encrypt_json = r#"{"method":"encrypt","params":{"data":"aGVsbG8gd29ybGQ=","app_name":"awsenc","key_label":"cache-key","biometric":false}}"#; + let destroy_json = + r#"{"method":"destroy","params":{"app_name":"awsenc","key_label":"cache-key"}}"#; let mut storage = None; // Init - let req: BridgeRequest = serde_json::from_str(init_json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(init_json).unwrap(); let resp = handle_request(&req, &mut storage); let resp_json = serde_json::to_string(&resp).unwrap(); assert!( @@ -295,7 +382,7 @@ mod tests { ); // Encrypt (will fail on non-Windows, which is expected) - let req: BridgeRequest = serde_json::from_str(encrypt_json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(encrypt_json).unwrap(); let resp = handle_request(&req, &mut storage); let resp_json = serde_json::to_string(&resp).unwrap(); assert!( @@ -304,16 +391,19 @@ mod tests { ); // Destroy - let req: BridgeRequest = serde_json::from_str(destroy_json).unwrap(); + let req: BridgeRequestCompat = serde_json::from_str(destroy_json).unwrap(); let resp = handle_request(&req, &mut storage); - assert!(resp.result.is_some()); + assert!( + resp.result.is_some() || resp.error.is_some(), + "destroy should return a structured response" + ); assert!(storage.is_none()); } #[test] fn invalid_json_produces_error() { let bad_json = "this is not json"; - let result = serde_json::from_str::(bad_json); + let result = serde_json::from_str::(bad_json); assert!(result.is_err()); } } diff --git a/awsenc-tpm-bridge/src/tpm.rs b/awsenc-tpm-bridge/src/tpm.rs index ab3fbeb..5407f13 100644 --- a/awsenc-tpm-bridge/src/tpm.rs +++ b/awsenc-tpm-bridge/src/tpm.rs @@ -8,59 +8,113 @@ //! //! On non-Windows platforms, all operations return an error at runtime. +use enclaveapp_core::metadata; +use enclaveapp_core::traits::{EnclaveEncryptor, EnclaveKeyManager}; +use enclaveapp_core::types::{AccessPolicy, KeyType}; +use std::path::Path; + +#[cfg_attr(not(any(test, target_os = "windows")), allow(dead_code))] +fn requested_policy(biometric: bool) -> AccessPolicy { + if biometric { + AccessPolicy::BiometricOnly + } else { + AccessPolicy::None + } +} + +#[cfg_attr(not(any(test, target_os = "windows")), allow(dead_code))] +fn existing_policy(keys_dir: &Path, key_label: &str) -> Option { + let meta_path = keys_dir.join(format!("{key_label}.meta")); + if !meta_path.exists() { + return None; + } + metadata::load_meta(keys_dir, key_label) + .ok() + .map(|meta| meta.access_policy) +} + +#[cfg_attr(not(any(test, target_os = "windows")), allow(dead_code))] +fn ensure_key( + encryptor: &E, + keys_dir: &Path, + key_label: &str, + biometric: bool, +) -> Result<(), String> +where + E: EnclaveEncryptor + EnclaveKeyManager, +{ + let policy = requested_policy(biometric); + + if encryptor.public_key(key_label).is_ok() { + match existing_policy(keys_dir, key_label) { + Some(existing) if existing != policy => { + encryptor + .delete_key(key_label) + .map_err(|e| format!("key deletion failed: {e}"))?; + } + _ => return Ok(()), + } + } + + encryptor + .generate(key_label, KeyType::Encryption, policy) + .map_err(|e| format!("key generation failed: {e}"))?; + Ok(()) +} + #[cfg(target_os = "windows")] mod platform { + use super::{ensure_key, metadata}; use enclaveapp_core::traits::{EnclaveEncryptor, EnclaveKeyManager}; - use enclaveapp_core::types::{AccessPolicy, KeyType}; use enclaveapp_windows::TpmEncryptor; - /// Application name for key namespacing. - const APP_NAME: &str = "awsenc"; - - /// Key label for the TPM bridge encryption key. - const KEY_LABEL: &str = "cache-key"; - pub struct TpmStorage { encryptor: TpmEncryptor, - #[allow(dead_code)] - biometric: bool, + key_label: String, } impl TpmStorage { - pub fn new(biometric: bool) -> Result { - let encryptor = TpmEncryptor::new(APP_NAME); + pub fn new(app_name: &str, key_label: &str, biometric: bool) -> Result { + let encryptor = TpmEncryptor::new(app_name); if !encryptor.is_available() { return Err("TPM not available".to_string()); } - // Ensure the key exists; generate if missing. - if encryptor.public_key(KEY_LABEL).is_err() { - let policy = if biometric { - AccessPolicy::BiometricOnly - } else { - AccessPolicy::None - }; - encryptor - .generate(KEY_LABEL, KeyType::Encryption, policy) - .map_err(|e| format!("key generation failed: {e}"))?; - } + ensure_key( + &encryptor, + &metadata::keys_dir(app_name), + key_label, + biometric, + )?; Ok(Self { encryptor, - biometric, + key_label: key_label.to_string(), }) } + pub fn delete(app_name: &str, key_label: &str) -> Result<(), String> { + let encryptor = TpmEncryptor::new(app_name); + + if !encryptor.is_available() { + return Err("TPM not available".to_string()); + } + + encryptor + .delete_key(key_label) + .map_err(|e| format!("key delete failed: {e}")) + } + pub fn encrypt(&self, plaintext: &[u8]) -> Result, String> { self.encryptor - .encrypt(KEY_LABEL, plaintext) + .encrypt(&self.key_label, plaintext) .map_err(|e| e.to_string()) } pub fn decrypt(&self, ciphertext: &[u8]) -> Result, String> { self.encryptor - .decrypt(KEY_LABEL, ciphertext) + .decrypt(&self.key_label, ciphertext) .map_err(|e| e.to_string()) } } @@ -69,17 +123,26 @@ mod platform { #[cfg(not(target_os = "windows"))] mod platform { pub struct TpmStorage { + _app_name: String, + _key_label: String, _biometric: bool, } impl TpmStorage { #[allow(clippy::unnecessary_wraps)] - pub fn new(biometric: bool) -> Result { + pub fn new(app_name: &str, key_label: &str, biometric: bool) -> Result { Ok(Self { + _app_name: app_name.to_string(), + _key_label: key_label.to_string(), _biometric: biometric, }) } + #[allow(clippy::unnecessary_wraps)] + pub fn delete(_app_name: &str, _key_label: &str) -> Result<(), String> { + Ok(()) + } + #[allow(clippy::unused_self)] pub fn encrypt(&self, _plaintext: &[u8]) -> Result, String> { Err("TPM bridge is only supported on Windows".to_string()) @@ -93,3 +156,182 @@ mod platform { } pub use platform::TpmStorage; + +#[cfg(test)] +#[allow(clippy::unwrap_used, clippy::panic)] +mod tests { + use super::*; + use enclaveapp_core::{Error, Result}; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::Mutex; + + static TEST_COUNTER: AtomicU64 = AtomicU64::new(0); + + fn test_dir() -> std::path::PathBuf { + let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst); + let pid = std::process::id(); + let dir = std::env::temp_dir().join(format!("awsenc-tpm-bridge-test-{pid}-{id}")); + std::fs::create_dir_all(&dir).unwrap(); + dir + } + + #[derive(Default)] + struct FakeState { + has_key: bool, + deleted: Vec, + generated: Vec<(String, KeyType, AccessPolicy)>, + } + + #[derive(Default)] + struct FakeEncryptor { + state: Mutex, + } + + impl FakeEncryptor { + fn with_existing_key() -> Self { + Self { + state: Mutex::new(FakeState { + has_key: true, + deleted: Vec::new(), + generated: Vec::new(), + }), + } + } + + fn deleted_labels(&self) -> Vec { + self.state.lock().unwrap().deleted.clone() + } + + fn generated_calls(&self) -> Vec<(String, KeyType, AccessPolicy)> { + self.state.lock().unwrap().generated.clone() + } + } + + impl EnclaveKeyManager for FakeEncryptor { + fn generate( + &self, + label: &str, + key_type: KeyType, + policy: AccessPolicy, + ) -> Result> { + let mut state = self.state.lock().map_err(|e| Error::KeyOperation { + operation: "lock".to_string(), + detail: e.to_string(), + })?; + state.has_key = true; + state.generated.push((label.to_string(), key_type, policy)); + Ok(vec![0x04; 65]) + } + + fn public_key(&self, label: &str) -> Result> { + let state = self.state.lock().map_err(|e| Error::KeyOperation { + operation: "lock".to_string(), + detail: e.to_string(), + })?; + if state.has_key { + Ok(vec![0x04; 65]) + } else { + Err(Error::KeyNotFound { + label: label.to_string(), + }) + } + } + + fn list_keys(&self) -> Result> { + Ok(Vec::new()) + } + + fn delete_key(&self, label: &str) -> Result<()> { + let mut state = self.state.lock().map_err(|e| Error::KeyOperation { + operation: "lock".to_string(), + detail: e.to_string(), + })?; + state.has_key = false; + state.deleted.push(label.to_string()); + Ok(()) + } + + fn is_available(&self) -> bool { + true + } + } + + impl EnclaveEncryptor for FakeEncryptor { + fn encrypt(&self, _label: &str, _plaintext: &[u8]) -> Result> { + Ok(Vec::new()) + } + + fn decrypt(&self, _label: &str, _ciphertext: &[u8]) -> Result> { + Ok(Vec::new()) + } + } + + #[test] + fn ensure_key_generates_when_missing() { + let dir = test_dir(); + let encryptor = FakeEncryptor::default(); + + ensure_key(&encryptor, &dir, "cache-key", true).unwrap(); + + assert!(encryptor.deleted_labels().is_empty()); + assert_eq!( + encryptor.generated_calls(), + vec![( + "cache-key".to_string(), + KeyType::Encryption, + AccessPolicy::BiometricOnly + )] + ); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn ensure_key_regenerates_when_policy_mismatches() { + let dir = test_dir(); + metadata::save_meta( + &dir, + "cache-key", + &metadata::KeyMeta::new("cache-key", KeyType::Encryption, AccessPolicy::None), + ) + .unwrap(); + let encryptor = FakeEncryptor::with_existing_key(); + + ensure_key(&encryptor, &dir, "cache-key", true).unwrap(); + + assert_eq!(encryptor.deleted_labels(), vec!["cache-key".to_string()]); + assert_eq!( + encryptor.generated_calls(), + vec![( + "cache-key".to_string(), + KeyType::Encryption, + AccessPolicy::BiometricOnly + )] + ); + + std::fs::remove_dir_all(&dir).unwrap(); + } + + #[test] + fn ensure_key_keeps_existing_key_when_policy_matches() { + let dir = test_dir(); + metadata::save_meta( + &dir, + "cache-key", + &metadata::KeyMeta::new( + "cache-key", + KeyType::Encryption, + AccessPolicy::BiometricOnly, + ), + ) + .unwrap(); + let encryptor = FakeEncryptor::with_existing_key(); + + ensure_key(&encryptor, &dir, "cache-key", true).unwrap(); + + assert!(encryptor.deleted_labels().is_empty()); + assert!(encryptor.generated_calls().is_empty()); + + std::fs::remove_dir_all(&dir).unwrap(); + } +}