Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 130 additions & 78 deletions crates/lib/src/api/client/tree.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use flate2::Compression;
use flate2::write::GzEncoder;
use futures_util::TryStreamExt;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use std::time;
use tempfile::TempDir;
use tokio_util::io::{ReaderStream, StreamReader, SyncIoBridge};

use crate::api;
use crate::api::client;
use crate::constants::{NODES_DIR, OXEN_HIDDEN_DIR, TREE_DIR};
use crate::core::db::merkle_node::file_backend;
use crate::core::progress::push_progress::PushProgress;
use crate::core::v_latest::index::CommitMerkleTree;
use crate::error::OxenError;
use crate::model::merkle_tree::merkle_transport::{
MerklePacker, MerkleUnpacker, PackOptions, UnpackOptions,
};
use crate::model::merkle_tree::node::MerkleTreeNode;
use crate::model::{LocalRepository, MerkleHash, RemoteRepository};
use crate::opts::download_tree_opts::DownloadTreeOpts;
use crate::opts::fetch_opts::FetchOpts;
use crate::view::tree::MerkleHashResponse;
use crate::view::tree::merkle_hashes::MerkleHashes;
use crate::view::{MerkleHashesResponse, StatusMessage};
use crate::{api, util};

/// Check if a node exists in the remote repository merkle tree by hash
#[tracing::instrument(skip_all)]
Expand Down Expand Up @@ -50,61 +49,71 @@ pub async fn has_node(
}
}

/// Upload a node to the remote repository merkle tree
/// Upload a set of merkle nodes to the remote repository.
///
/// Packs the nodes into the canonical tar-gz wire format and streams the bytes straight
/// into the HTTP upload body — no intermediate `Vec<u8>` is materialized. The pack runs
/// on a blocking worker (`tokio::task::spawn_blocking`) that writes into one end of a
/// `tokio::io::duplex`; the HTTP body reads from the other end through `ReaderStream`,
/// so upload and pack progress together with back-pressure.
pub async fn create_nodes(
local_repo: &LocalRepository,
remote_repo: &RemoteRepository,
nodes: HashSet<MerkleHash>,
progress: &Arc<PushProgress>,
) -> Result<(), OxenError> {
// Compress the node
log::debug!("create_nodes starting compression");
// OPT: Try Compression::fast();
let enc = GzEncoder::new(Vec::new(), Compression::default());
log::debug!("create_nodes compressing nodes");
let mut tar = tar::Builder::new(enc);
log::debug!("create_nodes creating tar");
let node_path = local_repo
.path
.join(OXEN_HIDDEN_DIR)
.join(TREE_DIR)
.join(NODES_DIR);

for (i, node_hash) in nodes.iter().enumerate() {
let dir_prefix = node_hash.to_hex_hash().node_db_prefix();
let node_dir = node_path.join(&dir_prefix);
// log::debug!(
// "create_nodes appending objects dir {:?} to tar at path {:?}",
// dir_prefix,
// node_dir
// );
progress.set_message(format!("Packing {}/{} nodes", i + 1, nodes.len()));

log::debug!("create_nodes appending dir to tar");
tar.append_dir_all(dir_prefix, node_dir)?;
}

tar.finish()?;
log::debug!("create_nodes finished tar");
let buffer: Vec<u8> = tar.into_inner()?.finish()?;
let n = nodes.len();
progress.set_message(format!("Pushing {n} nodes"));

// Extend the progress bar's total length by an uncompressed-bytes estimate of the
// tarball so the upload phase has a known end and a meaningful ETA. Random-ish
// merkle hash bytes compress to ~1.0×, so the uncompressed estimate is a tight
// upper bound on the bytes that will actually flow over the wire.
let estimated_upload_bytes = file_backend::pack_nodes_byte_estimate(local_repo, &nodes);
progress.inc_total_bytes(estimated_upload_bytes);

// Pack -> duplex writer (sync) -> duplex reader (async) -> HTTP body stream.
// 64 KiB duplex buffer mirrors the server-side streaming pattern in
// `crates/server/src/controllers/versions.rs`.
let (async_writer, async_reader) = tokio::io::duplex(64 * 1024);
let repo = local_repo.clone();
let pack_handle = tokio::task::spawn_blocking(move || -> Result<(), OxenError> {
let sync_writer = SyncIoBridge::new(async_writer);
// Legacy client-push wire format: required so older `oxen-server` deployments
// (which pre-pend `tree/nodes/` server-side at install time) install entries
// at the right paths.
repo.merkle_store()
.pack_nodes(&nodes, PackOptions::LegacyClientPush, sync_writer)?;
Ok(())
});

// Tick `progress` per chunk so the user sees upload progress moving.
let progress_for_stream = Arc::clone(progress);
let body_stream = ReaderStream::new(async_reader).inspect_ok(move |chunk| {
progress_for_stream.add_bytes(chunk.len() as u64);
});

// Upload the node
let uri = "/tree/nodes".to_string();
let url = api::endpoint::url_from_repo(remote_repo, &uri)?;
let client = client::builder_for_url(&url)?
.timeout(time::Duration::from_secs(120))
.build()?;
log::debug!("uploading {n} nodes to {url}");

let size = buffer.len() as u64;
log::debug!(
"uploading node of size {} to {}",
bytesize::ByteSize::b(size),
url
);
let res = client.post(&url).body(buffer.to_owned()).send().await?;
let res = client
.post(&url)
.body(reqwest::Body::wrap_stream(body_stream))
.send()
.await?;
let body = client::parse_json_body(&url, res).await?;
log::debug!("upload node complete {body}");

// Surface any pack error after the upload completes (the duplex reader reaching EOF
// signals pack end-of-stream; panics and Result::Err come through the join handle).
pack_handle
.await
.map_err(|e| OxenError::basic_str(format!("pack task panicked: {e}")))??;

Ok(())
}

Expand Down Expand Up @@ -330,6 +339,13 @@ pub async fn download_trees_between(
Ok(())
}

/// Download a merkle-tree tarball from the remote repository and unpack it into the
/// local store. Streams the response body straight into the `MerkleUnpacker` so nothing
/// buffers the whole payload in memory.
///
/// The VFS branch is preserved but no longer lives here: `FileBackend::unpack` handles
/// the `is_vfs` case internally (tempdir + `copy_dir_all` dance). That keeps the client
/// logic generic across backends.
async fn node_download_request(
local_repo: &LocalRepository,
url: impl AsRef<str>,
Expand All @@ -339,41 +355,25 @@ async fn node_download_request(
let client = client::builder_for_url(url)?
.timeout(time::Duration::from_secs(12000))
.build()?;
log::debug!("node_download_request about to send request {url}");
log::debug!("node_download_request sending request {url}");
let res = client.get(url).send().await?;
let res = client::handle_non_json_response(url, res).await?;

// The remote tar packs it in TREE_DIR/NODES_DIR
// So this will unpack it in OXEN_HIDDEN_DIR/TREE_DIR/NODES_DIR
let full_unpacked_path = local_repo.path.join(OXEN_HIDDEN_DIR);

// Create the temp path if it doesn't exist
util::fs::create_dir_all(&full_unpacked_path)?;

let reader = res
.bytes_stream()
.map_err(futures::io::Error::other)
.into_async_read();

let decoder = GzipDecoder::new(futures::io::BufReader::new(reader));
let archive = Archive::new(decoder);

// If the repo is stored on a virtual file system, re-route the nodes through a temp dir
if local_repo.is_vfs() {
let temp_dir = TempDir::new()?;
let temp_path = temp_dir.path();

// Unpack the tar in a temp dir
log::debug!("node_download_request unpacking to {temp_path:?}");
util::fs::unpack_async_tar_archive(archive, temp_path).await?;
log::debug!("Succesfully unpacked tar to temp dir");

// Copy to the repo
util::fs::copy_dir_all(&temp_dir, &full_unpacked_path)?;
} else {
// Else, unpack directly to the repo
util::fs::unpack_async_tar_archive(archive, &full_unpacked_path).await?;
}
// async Stream<Item = Result<Bytes, _>> → AsyncRead → sync Read, bridged across
// the spawn_blocking boundary so the sync trait consumes streamed bytes incrementally.
let async_reader = StreamReader::new(res.bytes_stream().map_err(std::io::Error::other));
let sync_reader = SyncIoBridge::new(async_reader);

let repo = local_repo.clone();
tokio::task::spawn_blocking(move || -> Result<(), OxenError> {
// Download path: overwrite existing files on disk, matching `main`'s
// `util::fs::unpack_async_tar_archive` behaviour.
repo.merkle_store()
.unpack(sync_reader, UnpackOptions::Overwrite)?;
Ok(())
})
.await
.map_err(|e| OxenError::basic_str(format!("unpack task panicked: {e}")))??;

Ok(())
}
Expand Down Expand Up @@ -470,6 +470,7 @@ pub async fn mark_nodes_as_synced(
mod tests {
use crate::api;
use crate::error::OxenError;
use crate::model::{Remote, RemoteRepository};
use crate::opts::FetchOpts;
use crate::repositories;
use crate::test;
Expand Down Expand Up @@ -674,4 +675,55 @@ mod tests {
})
.await
}

/// Regression: a corrupted gzip stream returned by the remote server must surface
/// as an `Err(OxenError)` from [`download_tree`] (and therefore the underlying
/// private `node_download_request`), **not** a panic on the spawn_blocking
/// `JoinHandle`.
///
/// The `node_download_request` pipeline pipes the response body through
/// `StreamReader` → `SyncIoBridge` → [`MerkleUnpacker::unpack`] inside a
/// `tokio::task::spawn_blocking`. The unpack returns `Err` on garbage gzip;
/// that `Err` must be propagated through the join handle as an `OxenError`,
/// not lost to a panic. Mockito serves a fixed garbage body so the test
/// doesn't depend on the live oxen-server.
#[tokio::test]
async fn test_node_download_request_propagates_corrupted_gzip_error() -> Result<(), OxenError> {
test::run_empty_local_repo_test_async(|local_repo| async move {
// Stand up a local mock HTTP server; `server.url()` becomes the base of the
// fake remote so no network access beyond localhost is needed.
let mut server = mockito::Server::new_async().await;
let server_url = server.url();
let namespace = "ns";
let name = "repo";
// download_tree uses URI "/tree/download", joined with API_NAMESPACE
// and the remote path -> "/api/repos/{namespace}/{name}/tree/download".
// NOTE(review): this path template mirrors `api::endpoint::url_from_repo` —
// confirm it stays in sync if the endpoint layout changes.
let path = format!("/api/repos/{namespace}/{name}/tree/download");
// Bound to `_mock` (a named underscore binding, not `_`) so the mock handle
// is not dropped before the request fires — presumably mockito unregisters
// on drop; verify against the mockito docs if this test flakes.
let _mock = server
.mock("GET", path.as_str())
.with_status(200)
.with_header("Content-Type", "application/gzip")
// Garbage bytes that aren't a valid gzip stream.
.with_body(b"NOT A VALID GZIP STREAM".as_slice())
.create_async()
.await;

// Minimal remote pointing at the mock server; only the fields the URL
// builder reads (`remote.url`, `namespace`, `name`) carry meaning here.
let remote_repo = RemoteRepository {
namespace: namespace.to_string(),
name: name.to_string(),
remote: Remote {
name: "origin".to_string(),
url: format!("{server_url}/{namespace}/{name}"),
},
min_version: None,
is_empty: false,
};

// The download must fail cleanly: an Err, not a panic propagated through
// the blocking unpack task's JoinHandle.
let res = api::client::tree::download_tree(&local_repo, &remote_repo).await;
assert!(
res.is_err(),
"download_tree must return Err for corrupted gzip stream, got {res:?}"
);
Ok(())
})
.await
}
}
Loading
Loading