Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

90 changes: 89 additions & 1 deletion codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions codex-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ async-stream = "0.3.6"
async-trait = "0.1.89"
aws-config = "1"
aws-credential-types = "1"
aws-lc-rs = { version = "=1.16.2", default-features = false, features = ["non-fips"] }
aws-sigv4 = "1"
aws-types = "1"
axum = { version = "0.8", default-features = false }
Expand All @@ -271,6 +272,13 @@ chardetng = "0.1.17"
chrono = "0.4.43"
clap = "4"
clap_complete = "4"
clatter = { version = "2.2.0", default-features = false, features = [
"alloc",
"getrandom",
"use-25519",
"use-aes-gcm",
"use-sha",
] }
color-eyre = "0.6.3"
constant_time_eq = "0.3.1"
crossbeam-channel = "0.5.15"
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/exec-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ workspace = true
[dependencies]
arc-swap = { workspace = true }
async-trait = { workspace = true }
aws-lc-rs = { workspace = true }
axum = { workspace = true, features = ["http1", "tokio", "ws"] }
base64 = { workspace = true }
bytes = { workspace = true }
clatter = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-api = { workspace = true }
codex-client = { workspace = true }
Expand Down
90 changes: 90 additions & 0 deletions codex-rs/exec-server/src/aws_lc_ml_kem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use aws_lc_rs::kem::Ciphertext;
use aws_lc_rs::kem::DecapsulationKey;
use aws_lc_rs::kem::EncapsulationKey;
use aws_lc_rs::kem::ML_KEM_768;
use clatter::KeyPair;
use clatter::bytearray::ByteArray;
use clatter::bytearray::HeapArray;
use clatter::bytearray::SensitiveByteArray;
use clatter::error::KemError;
use clatter::error::KemResult;
use clatter::traits::CryptoComponent;
use clatter::traits::Kem;
use clatter::traits::Rng;

pub(super) const PUBLIC_KEY_LEN: usize = 1184;
const SECRET_KEY_LEN: usize = 2400;
const CIPHERTEXT_LEN: usize = 1088;
const SHARED_SECRET_LEN: usize = 32;

/// ML-KEM-768 implementation backed by AWS-LC through `aws-lc-rs`.
#[derive(Clone)]
pub(super) struct AwsLcMlKem768;

impl CryptoComponent for AwsLcMlKem768 {
fn name() -> &'static str {
"MLKEM768"
}
}

impl Kem for AwsLcMlKem768 {
type SecretKey = SensitiveByteArray<HeapArray<SECRET_KEY_LEN>>;
type PubKey = HeapArray<PUBLIC_KEY_LEN>;
type Ct = HeapArray<CIPHERTEXT_LEN>;
type Ss = SensitiveByteArray<[u8; SHARED_SECRET_LEN]>;

fn genkey_rng<R: Rng>(_rng: &mut R) -> KemResult<KeyPair<Self::PubKey, Self::SecretKey>> {
// AWS-LC owns ML-KEM key-generation randomness internally, so
// Clatter's injectable RNG cannot be plumbed through this provider.
let decapsulation_key =
DecapsulationKey::generate(&ML_KEM_768).map_err(|_| KemError::KeyGeneration)?;
let encapsulation_key = decapsulation_key
.encapsulation_key()
.map_err(|_| KemError::KeyGeneration)?;
let public = encapsulation_key
.key_bytes()
.map_err(|_| KemError::KeyGeneration)?;
let secret = decapsulation_key
.key_bytes()
.map_err(|_| KemError::KeyGeneration)?;

Ok(KeyPair {
public: Self::PubKey::from_slice(public.as_ref()),
secret: Self::SecretKey::from_slice(secret.as_ref()),
})
}

fn encapsulate<R: Rng>(pk: &[u8], _rng: &mut R) -> KemResult<(Self::Ct, Self::Ss)> {
let encapsulation_key =
EncapsulationKey::new(&ML_KEM_768, pk).map_err(|_| KemError::Input)?;
let (ciphertext, shared_secret) = encapsulation_key
.encapsulate()
.map_err(|_| KemError::Encapsulation)?;

Ok((
Self::Ct::from_slice(ciphertext.as_ref()),
Self::Ss::from_slice(shared_secret.as_ref()),
))
}

fn decapsulate(ct: &[u8], sk: &[u8]) -> KemResult<Self::Ss> {
// Reject the length before constructing AWS-LC's ciphertext wrapper.
// This keeps malformed wire input classified as an input error rather
// than relying on provider-specific decapsulation behavior.
if ct.len() != CIPHERTEXT_LEN {
return Err(KemError::Input);
}

let decapsulation_key =
DecapsulationKey::new(&ML_KEM_768, sk).map_err(|_| KemError::Input)?;
let shared_secret = decapsulation_key
.decapsulate(Ciphertext::from(ct))
.map_err(|_| KemError::Decapsulation)?;

Ok(Self::Ss::from_slice(shared_secret.as_ref()))
}
}

#[cfg(test)]
#[path = "aws_lc_ml_kem_tests.rs"]
mod tests;
31 changes: 31 additions & 0 deletions codex-rs/exec-server/src/aws_lc_ml_kem_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use clatter::bytearray::ByteArray;
use clatter::traits::Kem;
use pretty_assertions::assert_eq;

use super::AwsLcMlKem768;

#[test]
fn kem_roundtrip() {
let keypair = AwsLcMlKem768::genkey().expect("generate keypair");
let mut rng = clatter::crypto::rng::DefaultRng;
let (ciphertext, encapsulated_secret) =
AwsLcMlKem768::encapsulate(keypair.public.as_slice(), &mut rng).expect("encapsulate");
let decapsulated_secret =
AwsLcMlKem768::decapsulate(ciphertext.as_slice(), keypair.secret.as_slice())
.expect("decapsulate");

assert_eq!(
encapsulated_secret.as_slice(),
decapsulated_secret.as_slice()
);
}

#[test]
fn decapsulate_rejects_wrong_ciphertext_length() {
let keypair = AwsLcMlKem768::genkey().expect("generate keypair");

let error = AwsLcMlKem768::decapsulate(&[], keypair.secret.as_slice())
.expect_err("empty ciphertext should be rejected");

assert!(matches!(error, clatter::error::KemError::Input));
}
5 changes: 5 additions & 0 deletions codex-rs/exec-server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod aws_lc_ml_kem;
mod client;
mod client_api;
mod client_transport;
Expand All @@ -11,6 +12,7 @@ mod fs_helper_main;
mod fs_sandbox;
mod local_file_system;
mod local_process;
mod noise_channel;
mod process;
mod process_id;
mod protocol;
Expand Down Expand Up @@ -51,6 +53,9 @@ pub use fs_helper::CODEX_FS_HELPER_ARG1;
pub use fs_helper_main::main as run_fs_helper_main;
pub use local_file_system::LOCAL_FS;
pub use local_file_system::LocalFileSystem;
pub use noise_channel::NoiseChannelError;
pub use noise_channel::NoiseChannelIdentity;
pub use noise_channel::NoiseChannelPublicKey;
pub use process::ExecBackend;
pub use process::ExecProcess;
pub use process::ExecProcessEvent;
Expand Down
Loading
Loading