diff --git a/.ckignore b/.ckignore new file mode 100644 index 0000000..04e615f --- /dev/null +++ b/.ckignore @@ -0,0 +1,65 @@ +# .ckignore - Default patterns for ck semantic search +# Created automatically during first index +# Syntax: same as .gitignore (glob patterns, ! for negation) + +# Images +*.png +*.jpg +*.jpeg +*.gif +*.bmp +*.svg +*.ico +*.webp +*.tiff + +# Video +*.mp4 +*.avi +*.mov +*.mkv +*.wmv +*.flv +*.webm + +# Audio +*.mp3 +*.wav +*.flac +*.aac +*.ogg +*.m4a + +# Binary/Compiled +*.exe +*.dll +*.so +*.dylib +*.a +*.lib +*.obj +*.o + +# Archives +*.zip +*.tar +*.tar.gz +*.tgz +*.rar +*.7z +*.bz2 +*.gz + +# Data files +*.db +*.sqlite +*.sqlite3 +*.parquet +*.arrow + +# Config formats (issue #27) +*.json +*.yaml +*.yml + +# Add your custom patterns below this line diff --git a/.gitignore b/.gitignore index ea8c4bf..0a1be3f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ /target +.idea/ +.claude/ +.ck/ + diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..0b3eeba --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,53 @@ +# AGENTS.md + +## Project Overview + +Bore is a modern, simple TCP tunnel in Rust that exposes local ports to a remote server, bypassing NAT. It's an alternative to localtunnel/ngrok. The crate is published as `bore-cli` with binary name `bore`. + +## Build & Development Commands + +```bash +cargo build --all-features # Build +cargo test # Run all tests +cargo test basic_proxy # Run a single test by name +cargo fmt -- --check # Check formatting +cargo clippy -- -D warnings # Lint (CI treats warnings as errors) +RUST_LOG=debug cargo test # Run tests with debug logging +``` + +## Architecture + +The codebase is ~400 lines of async Rust using Tokio. No unsafe code (`#![forbid(unsafe_code)]`). + +### Modules + +- **`main.rs`** — CLI entry point using clap. Two subcommands: `local` (client) and `server`. 
The `local` subcommand includes a reconnection loop with exponential backoff (enabled by default, disable with `--no-reconnect`). Authentication errors are classified as fatal via `is_auth_error()` and never retried. +- **`shared.rs`** — Protocol definitions. `ClientMessage`/`ServerMessage` enums serialized as JSON over TCP with null-byte delimiters. `Delimited` wraps any async stream for framed JSON I/O. Key constants: `CONTROL_PORT = 7835`, `MAX_FRAME_LENGTH = 256`, `NETWORK_TIMEOUT = 3s`, `HEARTBEAT_TIMEOUT = 8s`. Also contains `ExponentialBackoff` for reconnection delays and `set_tcp_keepalive()` for OS-level dead connection detection. +- **`auth.rs`** — Optional HMAC-SHA256 challenge-response authentication. Secret is SHA256-hashed before use. Constant-time comparison. +- **`client.rs`** — `Client` connects to server's control port, sends `Hello(port)`, receives assigned port. The `listen()` method wraps `recv()` in a heartbeat timeout (8s) to detect dead connections, returning an error instead of blocking forever. TCP keepalive is set on the control connection. For each incoming `Connection(uuid)`, opens a new TCP connection, sends `Accept(uuid)`, then bidirectionally proxies between local service and tunnel. +- **`server.rs`** — `Server` listens on control port. Allocates tunnel ports (random selection, 150 attempts). Stores pending connections in `DashMap` with 10-second expiry. Sends heartbeats every 500ms. + +### Protocol Flow + +1. Client connects to server on control port (7835) +2. Optional auth: server sends `Challenge(uuid)`, client responds with `Authenticate(hmac)` +3. Client sends `Hello(desired_port)`, server responds with `Hello(actual_port)` and starts tunnel listener +4. When external traffic hits the tunnel port, server stores the connection by UUID, sends `Connection(uuid)` to client +5. Client opens a new connection to server, sends `Accept(uuid)`, server pairs streams, bidirectional copy begins +6. 
If the control connection drops (heartbeat timeout or EOF), the client reconnects automatically with exponential backoff (unless `--no-reconnect` is set) + +### Key Patterns + +- `Delimited` for framed JSON messaging with null-byte delimiters (via `tokio_util::codec`) +- `Arc<DashMap>` for lock-free concurrent pending connection storage +- `Arc<Client>`/`Arc<Server>` shared across spawned Tokio tasks +- `tokio::io::copy_bidirectional` for efficient TCP proxying +- `anyhow::Result` with `.context()` for error propagation +- Heartbeat timeout on client `listen()` loop to detect dead connections (8s timeout, server heartbeats every 500ms) +- Exponential backoff with jitter for reconnection delays (1s base, configurable max) +- TCP keepalive via `socket2` as defense-in-depth for dead connection detection +- String-based error classification (`is_auth_error()`) to distinguish fatal from retriable errors + +## Testing + +Tests are in `tests/e2e_test.rs` (integration) and `tests/auth_test.rs` (auth unit tests). Integration tests use a `lazy_static` mutex (`SERIAL_GUARD`) to run serially and avoid port conflicts. CI retries tests up to 3 times due to potential port contention. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..34a4847 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,55 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Bore is a modern, simple TCP tunnel in Rust that exposes local ports to a remote server, bypassing NAT. It's an alternative to localtunnel/ngrok. The crate is published as `bore-cli` with binary name `bore`. 
+ +## Build & Development Commands + +```bash +cargo build --all-features # Build +cargo test # Run all tests +cargo test basic_proxy # Run a single test by name +cargo fmt -- --check # Check formatting +cargo clippy -- -D warnings # Lint (CI treats warnings as errors) +RUST_LOG=debug cargo test # Run tests with debug logging +``` + +## Architecture + +The codebase is ~400 lines of async Rust using Tokio. No unsafe code (`#![forbid(unsafe_code)]`). + +### Modules + +- **`main.rs`** — CLI entry point using clap. Two subcommands: `local` (client) and `server`. The `local` subcommand includes a reconnection loop with exponential backoff (enabled by default, disable with `--no-reconnect`). Authentication errors are classified as fatal via `is_auth_error()` and never retried. +- **`shared.rs`** — Protocol definitions. `ClientMessage`/`ServerMessage` enums serialized as JSON over TCP with null-byte delimiters. `Delimited` wraps any async stream for framed JSON I/O. Key constants: `CONTROL_PORT = 7835`, `MAX_FRAME_LENGTH = 256`, `NETWORK_TIMEOUT = 3s`, `HEARTBEAT_TIMEOUT = 8s`. Also contains `ExponentialBackoff` for reconnection delays and `set_tcp_keepalive()` for OS-level dead connection detection. +- **`auth.rs`** — Optional HMAC-SHA256 challenge-response authentication. Secret is SHA256-hashed before use. Constant-time comparison. +- **`client.rs`** — `Client` connects to server's control port, sends `Hello(port)`, receives assigned port. The `listen()` method wraps `recv()` in a heartbeat timeout (8s) to detect dead connections, returning an error instead of blocking forever. TCP keepalive is set on the control connection. For each incoming `Connection(uuid)`, opens a new TCP connection, sends `Accept(uuid)`, then bidirectionally proxies between local service and tunnel. +- **`server.rs`** — `Server` listens on control port. Allocates tunnel ports (random selection, 150 attempts). Stores pending connections in `DashMap` with 10-second expiry. 
Sends heartbeats every 500ms. + +### Protocol Flow + +1. Client connects to server on control port (7835) +2. Optional auth: server sends `Challenge(uuid)`, client responds with `Authenticate(hmac)` +3. Client sends `Hello(desired_port)`, server responds with `Hello(actual_port)` and starts tunnel listener +4. When external traffic hits the tunnel port, server stores the connection by UUID, sends `Connection(uuid)` to client +5. Client opens a new connection to server, sends `Accept(uuid)`, server pairs streams, bidirectional copy begins +6. If the control connection drops (heartbeat timeout or EOF), the client reconnects automatically with exponential backoff (unless `--no-reconnect` is set) + +### Key Patterns + +- `Delimited` for framed JSON messaging with null-byte delimiters (via `tokio_util::codec`) +- `Arc<DashMap>` for lock-free concurrent pending connection storage +- `Arc<Client>`/`Arc<Server>` shared across spawned Tokio tasks +- `tokio::io::copy_bidirectional` for efficient TCP proxying +- `anyhow::Result` with `.context()` for error propagation +- Heartbeat timeout on client `listen()` loop to detect dead connections (8s timeout, server heartbeats every 500ms) +- Exponential backoff with jitter for reconnection delays (1s base, configurable max) +- TCP keepalive via `socket2` as defense-in-depth for dead connection detection +- String-based error classification (`is_auth_error()`) to distinguish fatal from retriable errors + +## Testing + +Tests are in `tests/e2e_test.rs` (integration) and `tests/auth_test.rs` (auth unit tests). Integration tests use a `lazy_static` mutex (`SERIAL_GUARD`) to run serially and avoid port conflicts. CI retries tests up to 3 times due to potential port contention. 
diff --git a/Cargo.lock b/Cargo.lock index 17e8529..27b209f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -127,6 +127,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "socket2 0.5.10", "tokio", "tokio-util", "tracing", @@ -474,9 +475,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.142" +version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" [[package]] name = "linux-raw-sys" @@ -768,6 +769,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "strsim" version = "0.10.0" @@ -824,7 +835,7 @@ dependencies = [ "mio", "num_cpus", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio-macros", "windows-sys 0.48.0", ] @@ -998,6 +1009,15 @@ dependencies = [ "windows-targets 0.48.0", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -1028,6 +1048,22 @@ dependencies = [ "windows_x86_64_msvc 0.48.0", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + 
"windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -1040,6 +1076,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -1052,6 +1094,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -1064,6 +1112,18 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -1076,6 +1136,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -1088,6 +1154,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -1100,6 +1172,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -1111,3 +1189,9 @@ name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml index a0b402a..8e91131 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,8 @@ hmac = "0.12.1" serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.79" sha2 = "0.10.2" -tokio = { version = "1.17.0", features = ["rt-multi-thread", "io-util", "macros", "net", "time"] } +socket2 = { version = "0.5", features = ["all"] } +tokio = { version = "1.21.0", features = ["rt-multi-thread", "io-util", "macros", 
"net", "time"] } tokio-util = { version = "0.7.1", features = ["codec"] } tracing = "0.1.32" tracing-subscriber = "0.3.18" @@ -35,4 +36,4 @@ uuid = { version = "1.2.1", features = ["serde", "v4"] } [dev-dependencies] lazy_static = "1.4.0" rstest = "0.15.0" -tokio = { version = "1.17.0", features = ["sync"] } +tokio = { version = "1.21.0", features = ["sync"] } diff --git a/README.md b/README.md index 1e4e3d6..3c81ad5 100644 --- a/README.md +++ b/README.md @@ -96,11 +96,13 @@ Arguments: The local port to expose [env: BORE_LOCAL_PORT=] Options: - -l, --local-host The local host to expose [default: localhost] - -t, --to Address of the remote server to expose local ports to [env: BORE_SERVER=] - -p, --port Optional port on the remote server to select [default: 0] - -s, --secret Optional secret for authentication [env: BORE_SECRET] - -h, --help Print help + -l, --local-host The local host to expose [default: localhost] + -t, --to Address of the remote server to expose local ports to [env: BORE_SERVER=] + -p, --port Optional port on the remote server to select [default: 0] + -s, --secret Optional secret for authentication [env: BORE_SECRET] + --no-reconnect Disable automatic reconnection on connection loss + --max-reconnect-delay Maximum delay between reconnection attempts [default: 64] + -h, --help Print help ``` ### Self-Hosting @@ -139,6 +141,17 @@ Whenever the server obtains a connection on the remote port, it generates a secu For correctness reasons and to avoid memory leaks, incoming connections are only stored by the server for up to 10 seconds before being discarded if the client does not accept them. +## Reconnection + +By default, `bore` automatically reconnects to the server when the connection is lost (e.g., due to network interruptions). This makes it suitable for long-running deployments with service managers like systemd or launchd. + +- **Automatic reconnection** is enabled by default with exponential backoff (1s, 2s, 4s, ... 
up to 64s max) +- **Authentication failures** (wrong secret) are never retried — the client exits immediately +- **`--no-reconnect`** disables automatic reconnection, restoring the legacy exit-on-disconnect behavior +- **`--max-reconnect-delay <SECONDS>`** configures the maximum backoff delay (default: 64 seconds) + +Dead connections are detected via a heartbeat timeout: the server sends heartbeats every 500ms, and if no message is received within 8 seconds, the client treats the connection as dead and begins reconnecting. TCP keepalive is also configured as an additional safety net. + ## Authentication On a custom deployment of `bore server`, you can optionally require a _secret_ to prevent the server from being used by others. The protocol requires clients to verify possession of the secret on each TCP connection by answering random challenges in the form of HMAC codes. (This secret is only used for the initial handshake, and no further traffic is encrypted by default.) diff --git a/src/auth.rs b/src/auth.rs index ce8237c..d21dae2 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -48,6 +48,9 @@ impl Authenticator { } /// As the server, send a challenge to the client and validate their response. + /// + /// NOTE: Error messages here are matched by `is_auth_error()` in `main.rs` to + /// classify fatal auth errors. Do not change these strings without updating that function. pub async fn server_handshake( &self, stream: &mut Delimited, @@ -64,6 +67,9 @@ impl Authenticator { } /// As the client, answer a challenge to attempt to authenticate with the server. + /// + /// NOTE: Error messages here are matched by `is_auth_error()` in `main.rs` to + /// classify fatal auth errors. Do not change these strings without updating that function. 
pub async fn client_handshake( &self, stream: &mut Delimited, diff --git a/src/client.rs b/src/client.rs index cb8fa7b..f832ffc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,7 +8,10 @@ use tracing::{error, info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT, NETWORK_TIMEOUT}; +use crate::shared::{ + set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT, HEARTBEAT_TIMEOUT, + NETWORK_TIMEOUT, +}; /// State structure for the client. pub struct Client { @@ -40,7 +43,11 @@ impl Client { port: u16, secret: Option<&str>, ) -> Result { - let mut stream = Delimited::new(connect_with_timeout(to, CONTROL_PORT).await?); + let tcp_stream = connect_with_timeout(to, CONTROL_PORT).await?; + if let Err(e) = set_tcp_keepalive(&tcp_stream) { + warn!("TCP keepalive not available: {e:#}"); + } + let mut stream = Delimited::new(tcp_stream); let auth = secret.map(Authenticator::new); if let Some(auth) = &auth { auth.client_handshake(&mut stream).await?; @@ -51,6 +58,7 @@ impl Client { Some(ServerMessage::Hello(remote_port)) => remote_port, Some(ServerMessage::Error(message)) => bail!("server error: {message}"), Some(ServerMessage::Challenge(_)) => { + // NOTE: This message is matched by is_auth_error() in main.rs. bail!("server requires authentication, but no client secret was provided"); } Some(_) => bail!("unexpected initial non-hello message"), @@ -79,25 +87,32 @@ impl Client { let mut conn = self.conn.take().unwrap(); let this = Arc::new(self); loop { - match conn.recv().await? 
{ - Some(ServerMessage::Hello(_)) => warn!("unexpected hello"), - Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"), - Some(ServerMessage::Heartbeat) => (), - Some(ServerMessage::Connection(id)) => { - let this = Arc::clone(&this); - tokio::spawn( - async move { - info!("new connection"); - match this.handle_connection(id).await { - Ok(_) => info!("connection exited"), - Err(err) => warn!(%err, "connection exited with error"), - } - } - .instrument(info_span!("proxy", %id)), - ); + match timeout(HEARTBEAT_TIMEOUT, conn.recv()).await { + Err(_elapsed) => { + // No message received for HEARTBEAT_TIMEOUT seconds. + // Server sends heartbeats every 500ms, so connection is dead. + bail!("heartbeat timeout, connection to server lost"); } - Some(ServerMessage::Error(err)) => error!(%err, "server error"), - None => return Ok(()), + Ok(msg) => match msg? { + Some(ServerMessage::Hello(_)) => warn!("unexpected hello"), + Some(ServerMessage::Challenge(_)) => warn!("unexpected challenge"), + Some(ServerMessage::Heartbeat) => (), + Some(ServerMessage::Connection(id)) => { + let this = Arc::clone(&this); + tokio::spawn( + async move { + info!("new connection"); + match this.handle_connection(id).await { + Ok(_) => info!("connection exited"), + Err(err) => warn!(%err, "connection exited with error"), + } + } + .instrument(info_span!("proxy", %id)), + ); + } + Some(ServerMessage::Error(err)) => error!(%err, "server error"), + None => bail!("server closed connection"), + }, } } } diff --git a/src/main.rs b/src/main.rs index 71429c4..660180f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,11 @@ use std::net::IpAddr; +use std::time::Duration; use anyhow::Result; +use bore_cli::shared::ExponentialBackoff; use bore_cli::{client::Client, server::Server}; use clap::{error::ErrorKind, CommandFactory, Parser, Subcommand}; +use tracing::{info, warn}; #[derive(Parser, Debug)] #[clap(author, version, about)] @@ -34,6 +37,14 @@ enum Command { /// Optional secret for 
authentication. #[clap(short, long, env = "BORE_SECRET", hide_env_values = true)] secret: Option, + + /// Disable automatic reconnection on connection loss. + #[clap(long)] + no_reconnect: bool, + + /// Maximum delay between reconnection attempts, in seconds. + #[clap(long, default_value_t = 64, value_name = "SECONDS", value_parser = clap::value_parser!(u64).range(1..))] + max_reconnect_delay: u64, }, /// Runs the remote proxy server. @@ -60,6 +71,15 @@ enum Command { }, } +/// Check if an error is an authentication error that should not be retried. +fn is_auth_error(err: &anyhow::Error) -> bool { + let msg = format!("{err:#}"); + msg.contains("server requires authentication") + || msg.contains("invalid secret") + || msg.contains("server requires secret") + || msg.contains("expected authentication challenge") +} + #[tokio::main] async fn run(command: Command) -> Result<()> { match command { @@ -69,9 +89,57 @@ async fn run(command: Command) -> Result<()> { to, port, secret, + no_reconnect, + max_reconnect_delay, } => { - let client = Client::new(&local_host, local_port, &to, port, secret.as_deref()).await?; - client.listen().await?; + if no_reconnect { + // Legacy behavior: exit cleanly on disconnection, fail on auth errors + let client = + Client::new(&local_host, local_port, &to, port, secret.as_deref()).await?; + match client.listen().await { + Ok(()) => {} + Err(e) if is_auth_error(&e) => return Err(e), + Err(e) => { + warn!("disconnected: {e:#}"); + } + } + } else { + // Reconnection mode: infinite retry on transient failures + let mut backoff = ExponentialBackoff::new( + Duration::from_secs(1), + Duration::from_secs(max_reconnect_delay), + ); + + loop { + match Client::new(&local_host, local_port, &to, port, secret.as_deref()).await { + Ok(client) => { + backoff.reset(); + info!("connected to server"); + match client.listen().await { + Ok(()) => { + warn!("listen() exited cleanly, reconnecting"); + } + Err(e) => { + if is_auth_error(&e) { + return Err(e); + } 
+ warn!("connection lost: {e:#}"); + } + } + } + Err(e) => { + if is_auth_error(&e) { + return Err(e); + } + warn!("connection failed: {e:#}"); + } + } + + let delay = backoff.next_delay(); + info!("reconnecting in {delay:.1?}..."); + tokio::time::sleep(delay).await; + } + } } Command::Server { min_port, @@ -100,3 +168,37 @@ fn main() -> Result<()> { tracing_subscriber::fmt::init(); run(Args::parse().command) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_auth_error_detection() { + // Fatal auth errors — should NOT be retried + assert!(is_auth_error(&anyhow::anyhow!( + "server requires authentication, but no client secret was provided" + ))); + assert!(is_auth_error(&anyhow::anyhow!( + "server error: invalid secret" + ))); + assert!(is_auth_error(&anyhow::anyhow!( + "server error: server requires secret, but no secret was provided" + ))); + assert!(is_auth_error(&anyhow::anyhow!( + "expected authentication challenge, but no secret was required" + ))); + + // Retriable errors — should be retried + assert!(!is_auth_error(&anyhow::anyhow!( + "could not connect to server:7835" + ))); + assert!(!is_auth_error(&anyhow::anyhow!( + "heartbeat timeout, connection to server lost" + ))); + assert!(!is_auth_error(&anyhow::anyhow!( + "server error: port already in use" + ))); + assert!(!is_auth_error(&anyhow::anyhow!("server closed connection"))); + } +} diff --git a/src/server.rs b/src/server.rs index f47d714..0a47bb8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,7 +12,7 @@ use tracing::{info, info_span, warn, Instrument}; use uuid::Uuid; use crate::auth::Authenticator; -use crate::shared::{ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; +use crate::shared::{set_tcp_keepalive, ClientMessage, Delimited, ServerMessage, CONTROL_PORT}; /// State structure for the server. 
pub struct Server { @@ -116,6 +116,9 @@ impl Server { } async fn handle_connection(&self, stream: TcpStream) -> Result<()> { + if let Err(e) = set_tcp_keepalive(&stream) { + warn!("TCP keepalive not available: {e:#}"); + } let mut stream = Delimited::new(stream); if let Some(auth) = &self.auth { if let Err(err) = auth.server_handshake(&mut stream).await { diff --git a/src/shared.rs b/src/shared.rs index d9c5d3b..a50b221 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -5,7 +5,9 @@ use std::time::Duration; use anyhow::{Context, Result}; use futures_util::{SinkExt, StreamExt}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use socket2::{SockRef, TcpKeepalive}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio::time::timeout; use tokio_util::codec::{AnyDelimiterCodec, Framed, FramedParts}; use tracing::trace; @@ -20,6 +22,12 @@ pub const MAX_FRAME_LENGTH: usize = 256; /// Timeout for network connections and initial protocol messages. pub const NETWORK_TIMEOUT: Duration = Duration::from_secs(3); +/// Timeout for detecting a dead control connection. +/// +/// The server sends heartbeats every 500ms. If no message is received within +/// this duration, the connection is considered dead. +pub const HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(8); + /// A message from the client on the control connection. #[derive(Debug, Serialize, Deserialize)] pub enum ClientMessage { @@ -92,8 +100,95 @@ impl Delimited { Ok(()) } + /// Get a reference to the underlying transport stream. + pub fn get_ref(&self) -> &U { + self.0.get_ref() + } + /// Consume this object, returning current buffers and the inner transport. pub fn into_parts(self) -> FramedParts { self.0.into_parts() } } + +/// Simple exponential backoff with jitter for reconnection delays. 
+pub struct ExponentialBackoff { + current: Duration, + base: Duration, + max: Duration, +} + +impl ExponentialBackoff { + /// Create a new exponential backoff starting at `base` delay, capped at `max`. + pub fn new(base: Duration, max: Duration) -> Self { + Self { + current: base, + base, + max, + } + } + + /// Get the next delay and advance the backoff state. + /// Includes random jitter of +/- 25% to prevent thundering herd. + pub fn next_delay(&mut self) -> Duration { + let delay = self.current; + self.current = (self.current * 2).min(self.max); + // Add jitter: multiply by random factor between 0.75 and 1.25 + let jitter_factor = 0.75 + fastrand::f64() * 0.5; + delay.mul_f64(jitter_factor) + } + + /// Reset backoff to initial delay (call after successful connection). + pub fn reset(&mut self) { + self.current = self.base; + } +} + +/// Configure TCP keepalive on a stream for faster dead connection detection. +/// +/// This sets the OS to start probing after 30s of idle, probe every 10s, +/// and give up after 3 failed probes (~60s total to detect a dead connection). +pub fn set_tcp_keepalive(stream: &TcpStream) -> Result<()> { + let sock_ref = SockRef::from(stream); + let keepalive = TcpKeepalive::new() + .with_time(Duration::from_secs(30)) + .with_interval(Duration::from_secs(10)) + .with_retries(3); + sock_ref + .set_tcp_keepalive(&keepalive) + .context("failed to set TCP keepalive")?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_backoff_sequence() { + let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(30)); + // Delays should roughly double: 1, 2, 4, 8, 16, 30 (capped), 30, ... 
+ // With jitter, each delay is between 0.75x and 1.25x the base + for expected_base in [1, 2, 4, 8, 16, 30, 30] { + let delay = backoff.next_delay(); + let min = Duration::from_secs(expected_base).mul_f64(0.75); + let max = Duration::from_secs(expected_base).mul_f64(1.25); + assert!( + delay >= min && delay <= max, + "delay {delay:?} out of range [{min:?}, {max:?}]" + ); + } + } + + #[test] + fn test_backoff_reset() { + let mut backoff = ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60)); + backoff.next_delay(); // 1s + backoff.next_delay(); // 2s + backoff.next_delay(); // 4s + backoff.reset(); + let delay = backoff.next_delay(); + // After reset, should be back to ~1s (with jitter) + assert!(delay < Duration::from_secs(2)); + } +}