diff --git a/codex-rs/agent-identity/src/lib.rs b/codex-rs/agent-identity/src/lib.rs index 7aad81a34f1..79d3dd1a23f 100644 --- a/codex-rs/agent-identity/src/lib.rs +++ b/codex-rs/agent-identity/src/lib.rs @@ -1,4 +1,6 @@ use std::collections::BTreeMap; +use std::error::Error as StdError; +use std::fmt; use std::time::Duration; use anyhow::Context; @@ -34,21 +36,23 @@ const AGENT_TASK_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(30); const AGENT_IDENTITY_JWKS_TIMEOUT: Duration = Duration::from_secs(10); const AGENT_IDENTITY_JWT_AUDIENCE: &str = "codex-app-server"; const AGENT_IDENTITY_JWT_ISSUER: &str = "https://chatgpt.com/codex-backend/agent-identity"; - -/// Stored key material for a registered agent identity. +const AGENT_REGISTRATION_TIMEOUT: Duration = Duration::from_secs(15); +const PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL: &str = "https://auth.openai.com/api/accounts"; +const STAGING_AGENT_IDENTITY_AUTHAPI_BASE_URL: &str = "https://auth.api.openai.org/api/accounts"; +const AGENT_IDENTITY_KEY_SEED_BYTES: usize = 64; +const AGENT_IDENTITY_KEY_DERIVATION_CONTEXT: &[u8] = b"codex-agent-identity-ed25519-v1"; + +/// Borrowed durable signing material for a registered agent identity. +/// +/// This intentionally does not include a task id. Task ids are scoped to a +/// single Codex run, while the agent runtime id and private key are the +/// reusable identity material used to register and sign that run task. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct AgentIdentityKey<'a> { pub agent_runtime_id: &'a str, pub private_key_pkcs8_base64: &'a str, } -/// Task binding to use when constructing a task-scoped AgentAssertion. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub struct AgentTaskAuthorizationTarget<'a> { - pub agent_runtime_id: &'a str, - pub task_id: &'a str, -} - #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct AgentBillOfMaterials { pub agent_version: String, @@ -103,23 +107,92 @@ struct RegisterTaskResponse { encrypted_task_id_camel: Option, } +#[derive(Debug, Serialize)] +struct RegisterAgentRequest { + abom: AgentBillOfMaterials, + agent_public_key: String, + capabilities: Vec, + ttl: Option, +} + +#[derive(Debug, Deserialize)] +struct RegisterAgentResponse { + agent_runtime_id: String, +} + +/// HTTP status failure returned by Agent Identity registration endpoints. +#[derive(Debug)] +pub struct AgentIdentityRegistrationHttpError { + operation: &'static str, + status: reqwest::StatusCode, + body: String, +} + +impl AgentIdentityRegistrationHttpError { + fn new(operation: &'static str, status: reqwest::StatusCode, body: String) -> Self { + Self { + operation, + status, + body, + } + } + + /// HTTP status returned by the registration endpoint. + pub fn status(&self) -> reqwest::StatusCode { + self.status + } +} + +impl fmt::Display for AgentIdentityRegistrationHttpError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.body.is_empty() { + write!(f, "{} failed with status {}", self.operation, self.status) + } else { + write!( + f, + "{} failed with status {}: {}", + self.operation, self.status, self.body + ) + } + } +} + +impl StdError for AgentIdentityRegistrationHttpError {} + +/// Returns whether an Agent Identity registration error is safe to retry. +pub fn is_retryable_registration_error(error: &anyhow::Error) -> bool { + error.chain().any(is_retryable_registration_cause) +} + +fn is_retryable_registration_cause(cause: &(dyn StdError + 'static)) -> bool { + if let Some(error) = cause.downcast_ref::() { + return is_retryable_registration_status(error.status()); + } + + if let Some(error) = cause.downcast_ref::() { + if let Some(status) = error.status() { + return is_retryable_registration_status(status); + } + return error.is_timeout() || error.is_connect() || error.is_request(); + } + + false +} + +fn is_retryable_registration_status(status: reqwest::StatusCode) -> bool { + status == reqwest::StatusCode::TOO_MANY_REQUESTS || status.is_server_error() +} + pub fn authorization_header_for_agent_task( key: AgentIdentityKey<'_>, - target: AgentTaskAuthorizationTarget<'_>, + task_id: &str, ) -> Result { - anyhow::ensure!( - key.agent_runtime_id == target.agent_runtime_id, - "agent task runtime {} does not match stored agent identity {}", - target.agent_runtime_id, - key.agent_runtime_id - ); - let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true); let envelope = AgentAssertionEnvelope { - agent_runtime_id: target.agent_runtime_id.to_string(), - task_id: target.task_id.to_string(), + agent_runtime_id: key.agent_runtime_id.to_string(), + task_id: task_id.to_string(), timestamp: timestamp.clone(), - signature: sign_agent_assertion_payload(key, target.task_id, ×tamp)?, + signature: sign_agent_assertion_payload(key, task_id, ×tamp)?, }; let serialized_assertion = serialize_agent_assertion(&envelope)?; Ok(format!("AgentAssertion {serialized_assertion}")) @@ -127,10 +200,10 @@ pub fn authorization_header_for_agent_task( pub async fn fetch_agent_identity_jwks( client: &reqwest::Client, - chatgpt_base_url: &str, + agent_identity_jwt_base_url: &str, ) -> Result { let response = client - .get(agent_identity_jwks_url(chatgpt_base_url)) + .get(agent_identity_jwks_url(agent_identity_jwt_base_url)) .timeout(AGENT_IDENTITY_JWKS_TIMEOUT) .send() .await @@ -195,7 +268,7 @@ pub fn sign_task_registration_payload( pub async fn register_agent_task( client: &reqwest::Client, - chatgpt_base_url: &str, + agent_identity_authapi_base_url: &str, key: AgentIdentityKey<'_>, ) -> Result { let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Secs, true); @@ -203,7 +276,7 @@ pub async fn register_agent_task( signature: sign_task_registration_payload(key, ×tamp)?, timestamp, }; - let url = agent_task_registration_url(chatgpt_base_url, key.agent_runtime_id); + let url = agent_task_registration_url(agent_identity_authapi_base_url, key.agent_runtime_id); let response = client .post(url) @@ -220,7 +293,12 @@ pub async fn register_agent_task( } else { body }; - anyhow::bail!("failed to register agent task with status {status}: {body}"); + return Err(AgentIdentityRegistrationHttpError::new( + "agent task registration", + status, + body, + ) + .into()); } let response = response @@ -231,6 +309,45 @@ pub async fn register_agent_task( task_id_from_register_task_response(key, response) } +pub async fn register_agent_identity( + client: &reqwest::Client, + agent_identity_authapi_base_url: &str, + access_token: &str, + is_fedramp_account: bool, + key_material: &GeneratedAgentKeyMaterial, + abom: AgentBillOfMaterials, + capabilities: Vec, +) -> Result { + let url = agent_registration_url(agent_identity_authapi_base_url); + let request = RegisterAgentRequest { + abom, + agent_public_key: key_material.public_key_ssh.clone(), + capabilities, + ttl: None, + }; + + let mut request_builder = client + .post(&url) + .bearer_auth(access_token) + .json(&request) + .timeout(AGENT_REGISTRATION_TIMEOUT); + if is_fedramp_account { + request_builder = request_builder.header("X-OpenAI-Fedramp", "true"); + } + + let response = request_builder + .send() + .await + .with_context(|| format!("failed to send agent identity registration request to {url}"))? + .error_for_status() + .with_context(|| format!("agent identity registration failed for {url}"))? + .json::() + .await + .with_context(|| format!("failed to parse agent identity response from {url}"))?; + + Ok(response.agent_runtime_id) +} + fn task_id_from_register_task_response( key: AgentIdentityKey<'_>, response: RegisterTaskResponse, @@ -260,10 +377,17 @@ pub fn decrypt_task_id_response( } pub fn generate_agent_key_material() -> Result { - let mut secret_key_bytes = [0u8; 32]; + let mut seed_material = [0u8; AGENT_IDENTITY_KEY_SEED_BYTES]; OsRng - .try_fill_bytes(&mut secret_key_bytes) - .context("failed to generate agent identity private key bytes")?; + .try_fill_bytes(&mut seed_material) + .context("failed to generate agent identity private key seed material")?; + // Ed25519 stores a 32-byte seed, so derive it from all sampled seed material. + let mut digest = Sha512::new(); + digest.update(AGENT_IDENTITY_KEY_DERIVATION_CONTEXT); + digest.update(seed_material); + let digest = digest.finalize(); + let mut secret_key_bytes = [0u8; 32]; + secret_key_bytes.copy_from_slice(&digest[..32]); let signing_key = SigningKey::from_bytes(&secret_key_bytes); let private_key_pkcs8 = signing_key .to_pkcs8_der() @@ -296,23 +420,22 @@ pub fn curve25519_secret_key_from_private_key_pkcs8_base64( Ok(curve25519_secret_key_from_signing_key(&signing_key)) } -pub fn agent_registration_url(chatgpt_base_url: &str) -> String { - let trimmed = chatgpt_base_url.trim_end_matches('/'); - format!("{trimmed}/v1/agent/register") +pub fn agent_registration_url(agent_identity_authapi_base_url: &str) -> String { + agent_identity_authapi_url(agent_identity_authapi_base_url, "/v1/agent/register") } -pub fn agent_task_registration_url(chatgpt_base_url: &str, agent_runtime_id: &str) -> String { - let trimmed = chatgpt_base_url.trim_end_matches('/'); - format!("{trimmed}/v1/agent/{agent_runtime_id}/task/register") +pub fn agent_task_registration_url( + agent_identity_authapi_base_url: &str, + agent_runtime_id: &str, +) -> String { + agent_identity_authapi_url( + agent_identity_authapi_base_url, + &format!("/v1/agent/{agent_runtime_id}/task/register"), + ) } -pub fn agent_identity_biscuit_url(chatgpt_base_url: &str) -> String { - let trimmed = chatgpt_base_url.trim_end_matches('/'); - format!("{trimmed}/authenticate_app_v2") -} - -pub fn agent_identity_jwks_url(chatgpt_base_url: &str) -> String { - let trimmed = chatgpt_base_url.trim_end_matches('/'); +pub fn agent_identity_jwks_url(agent_identity_jwt_base_url: &str) -> String { + let trimmed = agent_identity_jwt_base_url.trim_end_matches('/'); if trimmed.contains("/backend-api") { format!("{trimmed}/wham/agent-identities/jwks") } else { @@ -320,15 +443,59 @@ pub fn agent_identity_jwks_url(chatgpt_base_url: &str) -> String { } } -pub fn agent_identity_request_id() -> Result { - let mut request_id_bytes = [0u8; 16]; - OsRng - .try_fill_bytes(&mut request_id_bytes) - .context("failed to generate agent identity request id")?; - Ok(format!( - "codex-agent-identity-{}", - URL_SAFE_NO_PAD.encode(request_id_bytes) - )) +fn agent_identity_authapi_url(agent_identity_authapi_base_url: &str, api_path: &str) -> String { + let base_url = normalize_agent_identity_authapi_base_url(agent_identity_authapi_base_url); + format!("{base_url}{api_path}") +} + +pub fn agent_identity_authapi_base_url_from_chatgpt_base_url(chatgpt_base_url: &str) -> String { + let mut base_url = chatgpt_base_url.trim_end_matches('/').to_string(); + for suffix in [ + "/wham/remote/control/server/enroll", + "/wham/remote/control/server", + ] { + if let Some(stripped) = base_url.strip_suffix(suffix) { + base_url = stripped.to_string(); + break; + } + } + if matches!( + base_url.as_str(), + "https://chatgpt.com/codex" + | "https://chatgpt.com/backend-api/codex" + | "https://chat.openai.com/codex" + | "https://chat.openai.com/backend-api/codex" + | "https://chatgpt-staging.com/codex" + | "https://chatgpt-staging.com/backend-api/codex" + ) && let Some(stripped) = base_url.strip_suffix("/codex") + { + base_url = stripped.to_string(); + } + + match base_url.as_str() { + "https://chatgpt.com" | "https://chatgpt.com/backend-api" => { + PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL.to_string() + } + "https://chat.openai.com" | "https://chat.openai.com/backend-api" => { + PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL.to_string() + } + "https://chatgpt-staging.com" | "https://chatgpt-staging.com/backend-api" => { + STAGING_AGENT_IDENTITY_AUTHAPI_BASE_URL.to_string() + } + _ => normalize_agent_identity_authapi_base_url(&base_url), + } +} + +fn normalize_agent_identity_authapi_base_url(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if matches!( + trimmed, + "https://auth.openai.com" | "https://auth.api.openai.org" + ) { + format!("{trimmed}/api/accounts") + } else { + trimmed.to_string() + } } pub fn build_abom(session_source: SessionSource) -> AgentBillOfMaterials { @@ -412,6 +579,24 @@ mod tests { use super::*; + #[test] + fn register_task_request_uses_single_run_task_shape() { + let request = RegisterTaskRequest { + timestamp: "2026-04-23T00:00:00Z".to_string(), + signature: "signature".to_string(), + }; + + let serialized = serde_json::to_value(request).expect("serialize request"); + + assert_eq!( + serialized, + serde_json::json!({ + "timestamp": "2026-04-23T00:00:00Z", + "signature": "signature", + }) + ); + } + #[test] fn authorization_header_for_agent_task_serializes_signed_agent_assertion() { let signing_key = SigningKey::from_bytes(&[7u8; 32]); @@ -422,13 +607,9 @@ mod tests { agent_runtime_id: "agent-123", private_key_pkcs8_base64: &BASE64_STANDARD.encode(private_key.as_bytes()), }; - let target = AgentTaskAuthorizationTarget { - agent_runtime_id: "agent-123", - task_id: "task-123", - }; - let header = - authorization_header_for_agent_task(key, target).expect("build agent assertion header"); + let header = authorization_header_for_agent_task(key, "task-123") + .expect("build agent assertion header"); let token = header .strip_prefix("AgentAssertion ") .expect("agent assertion scheme"); @@ -464,31 +645,6 @@ mod tests { .expect("signature should verify"); } - #[test] - fn authorization_header_for_agent_task_rejects_mismatched_runtime() { - let signing_key = SigningKey::from_bytes(&[7u8; 32]); - let private_key = signing_key - .to_pkcs8_der() - .expect("encode test key material"); - let private_key_pkcs8_base64 = BASE64_STANDARD.encode(private_key.as_bytes()); - let key = AgentIdentityKey { - agent_runtime_id: "agent-123", - private_key_pkcs8_base64: &private_key_pkcs8_base64, - }; - let target = AgentTaskAuthorizationTarget { - agent_runtime_id: "agent-456", - task_id: "task-123", - }; - - let error = authorization_header_for_agent_task(key, target) - .expect_err("runtime mismatch should fail"); - - assert_eq!( - error.to_string(), - "agent task runtime agent-456 does not match stored agent identity agent-123" - ); - } - #[test] fn decode_agent_identity_jwt_reads_claims() { let jwt = jwt_with_payload(serde_json::json!({ @@ -704,7 +860,117 @@ J1bwkqKZTB5dHolX9A58e/xXnfZ5P8f3Z83+Izap3FwqQulk7b1WO1MQcHuVg2NN } #[test] - fn agent_identity_jwks_url_uses_backend_api_base_url() { + fn agent_identity_authapi_base_url_from_chatgpt_base_url_uses_public_authapi() { + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url("https://chatgpt.com/codex"), + "https://auth.openai.com/api/accounts" + ); + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url( + "https://chatgpt.com/backend-api/codex" + ), + "https://auth.openai.com/api/accounts" + ); + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url( + "https://chatgpt-staging.com/backend-api" + ), + "https://auth.api.openai.org/api/accounts" + ); + } + + #[test] + fn agent_identity_authapi_base_url_from_chatgpt_base_url_normalizes_authapi_root() { + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url("https://auth.openai.com"), + "https://auth.openai.com/api/accounts" + ); + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url( + "https://auth.api.openai.org/api/accounts/" + ), + "https://auth.api.openai.org/api/accounts" + ); + } + + #[test] + fn agent_identity_authapi_base_url_from_chatgpt_base_url_preserves_local_base() { + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url("http://localhost:8080"), + "http://localhost:8080" + ); + assert_eq!( + agent_identity_authapi_base_url_from_chatgpt_base_url( + "http://localhost:8080/api/codex" + ), + "http://localhost:8080/api/codex" + ); + } + + #[test] + fn agent_registration_url_appends_to_authapi_base_url() { + assert_eq!( + agent_registration_url("https://auth.openai.com/api/accounts"), + "https://auth.openai.com/api/accounts/v1/agent/register" + ); + assert_eq!( + agent_registration_url("http://localhost:8080"), + "http://localhost:8080/v1/agent/register" + ); + assert_eq!( + agent_registration_url("http://localhost:8080/backend-api"), + "http://localhost:8080/backend-api/v1/agent/register" + ); + } + + #[test] + fn agent_task_registration_url_appends_to_authapi_base_url() { + assert_eq!( + agent_task_registration_url("https://auth.openai.com/api/accounts", "agent-runtime-id"), + "https://auth.openai.com/api/accounts/v1/agent/agent-runtime-id/task/register" + ); + assert_eq!( + agent_task_registration_url("https://auth.openai.com", "agent-runtime-id"), + "https://auth.openai.com/api/accounts/v1/agent/agent-runtime-id/task/register" + ); + assert_eq!( + agent_task_registration_url("http://localhost:8080", "agent-runtime-id"), + "http://localhost:8080/v1/agent/agent-runtime-id/task/register" + ); + } + + #[test] + fn retryable_registration_error_accepts_429_and_5xx() { + let too_many_requests = anyhow::Error::new(AgentIdentityRegistrationHttpError::new( + "agent registration", + reqwest::StatusCode::TOO_MANY_REQUESTS, + "rate limited".to_string(), + )); + let unavailable = anyhow::Error::new(AgentIdentityRegistrationHttpError::new( + "agent registration", + reqwest::StatusCode::SERVICE_UNAVAILABLE, + "try later".to_string(), + )); + + assert!(is_retryable_registration_error(&too_many_requests)); + assert!(is_retryable_registration_error(&unavailable)); + } + + #[test] + fn retryable_registration_error_rejects_hard_failures() { + let forbidden = anyhow::Error::new(AgentIdentityRegistrationHttpError::new( + "agent registration", + reqwest::StatusCode::FORBIDDEN, + "not allowed".to_string(), + )); + let malformed = anyhow::anyhow!("failed to sign registration request"); + + assert!(!is_retryable_registration_error(&forbidden)); + assert!(!is_retryable_registration_error(&malformed)); + } + + #[test] + fn agent_identity_jwks_url_uses_agent_identity_jwt_route() { assert_eq!( agent_identity_jwks_url("https://chatgpt.com/backend-api"), "https://chatgpt.com/backend-api/wham/agent-identities/jwks" @@ -716,7 +982,7 @@ J1bwkqKZTB5dHolX9A58e/xXnfZ5P8f3Z83+Izap3FwqQulk7b1WO1MQcHuVg2NN } #[test] - fn agent_identity_jwks_url_uses_codex_api_base_url() { + fn agent_identity_jwks_url_uses_jwt_issuer_base_url() { assert_eq!( agent_identity_jwks_url("http://localhost:8080/api/codex"), "http://localhost:8080/api/codex/agent-identities/jwks" diff --git a/codex-rs/login/src/auth/agent_identity.rs b/codex-rs/login/src/auth/agent_identity.rs index 3644713328f..c99f7e4d85c 100644 --- a/codex-rs/login/src/auth/agent_identity.rs +++ b/codex-rs/login/src/auth/agent_identity.rs @@ -1,43 +1,61 @@ +use std::sync::Arc; + use codex_agent_identity::AgentIdentityKey; +use codex_agent_identity::agent_identity_authapi_base_url_from_chatgpt_base_url; use codex_agent_identity::register_agent_task; use codex_protocol::account::PlanType as AccountPlanType; -use std::env; +use tokio::sync::OnceCell; use crate::default_client::build_reqwest_client; use super::storage::AgentIdentityAuthRecord; -const PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL: &str = "https://auth.openai.com/api/accounts"; -const CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR: &str = "CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL"; +const DEFAULT_CHATGPT_BACKEND_BASE_URL: &str = "https://chatgpt.com/backend-api"; -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct AgentIdentityAuth { record: AgentIdentityAuthRecord, - process_task_id: String, + run_task_id: Arc>, +} + +impl Clone for AgentIdentityAuth { + fn clone(&self) -> Self { + Self { + record: self.record.clone(), + run_task_id: Arc::clone(&self.run_task_id), + } + } } impl AgentIdentityAuth { - pub async fn load(record: AgentIdentityAuthRecord) -> std::io::Result { - let agent_identity_authapi_base_url = agent_identity_authapi_base_url(); - let process_task_id = register_agent_task( - &build_reqwest_client(), - &agent_identity_authapi_base_url, - key(&record), - ) - .await - .map_err(std::io::Error::other)?; - Ok(Self { + pub fn new(record: AgentIdentityAuthRecord) -> Self { + Self { record, - process_task_id, - }) + run_task_id: Arc::new(OnceCell::new()), + } } pub fn record(&self) -> &AgentIdentityAuthRecord { &self.record } - pub fn process_task_id(&self) -> &str { - &self.process_task_id + pub fn run_task_id(&self) -> Option { + self.run_task_id.get().cloned() + } + + pub async fn ensure_run_task(&self, chatgpt_base_url: Option) -> std::io::Result<()> { + self.run_task_id_for(chatgpt_base_url).await.map(|_| ()) + } + + pub async fn register_task(&self, chatgpt_base_url: Option) -> std::io::Result { + let authapi_base_url = agent_identity_authapi_base_url_from_chatgpt_base_url( + chatgpt_base_url + .as_deref() + .unwrap_or(DEFAULT_CHATGPT_BACKEND_BASE_URL), + ); + register_agent_task(&build_reqwest_client(), &authapi_base_url, self.key()) + .await + .map_err(std::io::Error::other) } pub fn account_id(&self) -> &str { @@ -59,82 +77,148 @@ impl AgentIdentityAuth { pub fn is_fedramp_account(&self) -> bool { self.record.chatgpt_account_is_fedramp } -} - -fn agent_identity_authapi_base_url() -> String { - env::var(CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR) - .ok() - .map(|base_url| base_url.trim().trim_end_matches('/').to_string()) - .filter(|base_url| !base_url.is_empty()) - .unwrap_or_else(|| PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL.to_string()) -} + fn key(&self) -> AgentIdentityKey<'_> { + AgentIdentityKey { + agent_runtime_id: &self.record.agent_runtime_id, + private_key_pkcs8_base64: &self.record.agent_private_key, + } + } -fn key(record: &AgentIdentityAuthRecord) -> AgentIdentityKey<'_> { - AgentIdentityKey { - agent_runtime_id: &record.agent_runtime_id, - private_key_pkcs8_base64: &record.agent_private_key, + async fn run_task_id_for(&self, chatgpt_base_url: Option) -> std::io::Result { + self.run_task_id + .get_or_try_init(|| async { self.register_task(chatgpt_base_url).await }) + .await + .cloned() } } #[cfg(test)] mod tests { + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + + use codex_agent_identity::generate_agent_key_material; + use pretty_assertions::assert_eq; + use serde_json::json; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; + use super::*; - use serial_test::serial; - #[test] - #[serial(codex_auth_env)] - fn agent_identity_authapi_base_url_prefers_env_value() { - let _guard = EnvVarGuard::set( - CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR, - "https://authapi.example.test/api/accounts/", - ); - assert_eq!( - agent_identity_authapi_base_url(), - "https://authapi.example.test/api/accounts" - ); + fn agent_identity_record(private_key: String) -> AgentIdentityAuthRecord { + AgentIdentityAuthRecord { + agent_runtime_id: "agent-runtime-1".to_string(), + agent_private_key: private_key, + account_id: "account-1".to_string(), + chatgpt_user_id: "user-1".to_string(), + email: "agent@example.com".to_string(), + plan_type: AccountPlanType::Plus, + chatgpt_account_is_fedramp: false, + } } - #[test] - #[serial(codex_auth_env)] - fn agent_identity_authapi_base_url_uses_prod_authapi_by_default() { - let _guard = EnvVarGuard::remove(CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL_ENV_VAR); - assert_eq!( - agent_identity_authapi_base_url(), - PROD_AGENT_IDENTITY_AUTHAPI_BASE_URL - ); + fn agent_identity_auth() -> AgentIdentityAuth { + let key_material = generate_agent_key_material().expect("generate key material"); + AgentIdentityAuth::new(agent_identity_record(key_material.private_key_pkcs8_base64)) } - struct EnvVarGuard { - key: &'static str, - original: Option, + #[tokio::test] + async fn ensure_run_task_registers_once() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-run-1", + }))) + .expect(1) + .mount(&server) + .await; + + auth.ensure_run_task(Some(server.uri())).await?; + auth.ensure_run_task(Some(server.uri())).await?; + + assert_eq!(auth.run_task_id(), Some("task-run-1".to_string())); + let requests = server + .received_requests() + .await + .expect("failed to fetch task registration request"); + let request_body = requests[0] + .body_json::() + .expect("task registration request should be JSON"); + let request_body = request_body + .as_object() + .expect("request body should be object"); + assert!(request_body.get("timestamp").is_some()); + assert!(request_body.get("signature").is_some()); + assert_eq!(request_body.len(), 2); + Ok(()) } - impl EnvVarGuard { - fn set(key: &'static str, value: &str) -> Self { - let original = env::var_os(key); - unsafe { - env::set_var(key, value); - } - Self { key, original } - } - - fn remove(key: &'static str) -> Self { - let original = env::var_os(key); - unsafe { - env::remove_var(key); - } - Self { key, original } - } + #[tokio::test] + async fn run_task_is_shared_across_clones() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let cloned = auth.clone(); + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-run-1", + }))) + .expect(1) + .mount(&server) + .await; + + auth.ensure_run_task(Some(server.uri())).await?; + cloned.ensure_run_task(Some(server.uri())).await?; + + assert_eq!(cloned.run_task_id(), Some("task-run-1".to_string())); + Ok(()) } - impl Drop for EnvVarGuard { - fn drop(&mut self) { - unsafe { - match &self.original { - Some(value) => env::set_var(self.key, value), - None => env::remove_var(self.key), + #[tokio::test] + async fn failed_run_task_registration_can_retry() -> anyhow::Result<()> { + let auth = agent_identity_auth(); + let server = MockServer::start().await; + let request_count = Arc::new(AtomicUsize::new(0)); + let response_count = Arc::clone(&request_count); + Mock::given(method("POST")) + .and(path("/v1/agent/agent-runtime-1/task/register")) + .respond_with(move |_request: &wiremock::Request| { + if response_count.fetch_add(1, Ordering::SeqCst) == 0 { + ResponseTemplate::new(500) + } else { + ResponseTemplate::new(200).set_body_json(json!({ + "task_id": "task-run-1", + })) } - } - } + }) + .expect(2) + .mount(&server) + .await; + + auth.ensure_run_task(Some(server.uri())) + .await + .expect_err("first registration should fail"); + auth.ensure_run_task(Some(server.uri())).await?; + + assert_eq!(request_count.load(Ordering::SeqCst), 2); + assert_eq!(auth.run_task_id(), Some("task-run-1".to_string())); + Ok(()) + } + + #[test] + fn run_task_id_is_shared_across_clones() { + let auth = agent_identity_auth(); + let cloned = auth.clone(); + auth.run_task_id + .set("task-run-1".to_string()) + .expect("run task should be unset"); + + assert_eq!(cloned.run_task_id(), Some("task-run-1".to_string())); } } diff --git a/codex-rs/login/src/auth/auth_tests.rs b/codex-rs/login/src/auth/auth_tests.rs index e7b5502ebcd..d69712795e1 100644 --- a/codex-rs/login/src/auth/auth_tests.rs +++ b/codex-rs/login/src/auth/auth_tests.rs @@ -819,8 +819,6 @@ async fn load_auth_reads_access_token_from_env() { let _access_token_guard = EnvVarGuard::set(CODEX_ACCESS_TOKEN_ENV_VAR, &agent_identity); let chatgpt_base_url = format!("{}/backend-api", server.uri()); - let _authapi_guard = - EnvVarGuard::set("CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL", &chatgpt_base_url); let auth = super::load_auth( codex_home.path(), /*enable_codex_api_key_env*/ false, @@ -835,7 +833,7 @@ async fn load_auth_reads_access_token_from_env() { panic!("env auth should load as agent identity"); }; assert_eq!(agent_identity.record(), &expected_record); - assert_eq!(agent_identity.process_task_id(), "task-123"); + assert_eq!(agent_identity.run_task_id(), Some("task-123".to_string())); assert!( !get_auth_file(codex_home.path()).exists(), "env auth should not write auth.json" @@ -1101,8 +1099,6 @@ async fn enforce_login_restrictions_logs_out_for_agent_identity_workspace_mismat .mount(&server) .await; let chatgpt_base_url = format!("{}/backend-api", server.uri()); - let _authapi_guard = - EnvVarGuard::set("CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL", &chatgpt_base_url); save_auth( codex_home.path(), &AuthDotJson { @@ -1338,8 +1334,6 @@ async fn assert_agent_identity_plan_alias( .mount(&server) .await; let chatgpt_base_url = format!("{}/backend-api", server.uri()); - let _authapi_guard = - EnvVarGuard::set("CODEX_AGENT_IDENTITY_AUTHAPI_BASE_URL", &chatgpt_base_url); let auth = CodexAuth::from_agent_identity_jwt(&jwt, Some(&chatgpt_base_url)) .await .expect("agent identity auth"); diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index 9e255a99a3e..5d2ac990a5e 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -281,7 +281,9 @@ impl CodexAuth { .trim_end_matches('/') .to_string(); let record = verified_agent_identity_record(jwt, &base_url).await?; - Ok(Self::AgentIdentity(AgentIdentityAuth::load(record).await?)) + let auth = AgentIdentityAuth::new(record); + auth.ensure_run_task(Some(base_url)).await?; + Ok(Self::AgentIdentity(auth)) } pub async fn from_personal_access_token(access_token: &str) -> std::io::Result { diff --git a/codex-rs/model-provider/src/auth.rs b/codex-rs/model-provider/src/auth.rs index 4f87d46baec..61727a8fb8f 100644 --- a/codex-rs/model-provider/src/auth.rs +++ b/codex-rs/model-provider/src/auth.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use codex_agent_identity::AgentIdentityKey; -use codex_agent_identity::AgentTaskAuthorizationTarget; use codex_agent_identity::authorization_header_for_agent_task; use codex_api::AuthProvider; use codex_api::SharedAuthProvider; @@ -21,17 +20,20 @@ struct AgentIdentityAuthProvider { impl AuthProvider for AgentIdentityAuthProvider { fn add_auth_headers(&self, headers: &mut HeaderMap) { let record = self.auth.record(); - let header_value = authorization_header_for_agent_task( - AgentIdentityKey { - agent_runtime_id: &record.agent_runtime_id, - private_key_pkcs8_base64: &record.agent_private_key, - }, - AgentTaskAuthorizationTarget { - agent_runtime_id: &record.agent_runtime_id, - task_id: self.auth.process_task_id(), - }, - ) - .map_err(std::io::Error::other); + let header_value = self + .auth + .run_task_id() + .ok_or_else(|| std::io::Error::other("agent identity run task is not initialized")) + .and_then(|task_id| { + authorization_header_for_agent_task( + AgentIdentityKey { + agent_runtime_id: &record.agent_runtime_id, + private_key_pkcs8_base64: &record.agent_private_key, + }, + &task_id, + ) + .map_err(std::io::Error::other) + }); if let Ok(header_value) = header_value && let Ok(header) = HeaderValue::from_str(&header_value)