Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion zbus/src/address/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl Transport {

#[cfg(not(unix))]
{
let _ = path;
let _ = stream;
Comment thread
elmarco marked this conversation as resolved.
Err(Error::Unsupported)
}
}
Expand Down
60 changes: 30 additions & 30 deletions zbus/src/connection/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
async_lock::RwLock,
names::{InterfaceName, UniqueName, WellKnownName},
object_server::Interface,
Connection, Error, Executor, Guid, Result,
Connection, Error, Executor, Guid, OwnedGuid, Result,
};

use super::{
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<'a> Builder<'a> {

/// Create a builder for connection that will use the given socket.
pub fn socket<S: Socket + 'static>(socket: S) -> Self {
Self::new(Target::Socket(Split::new_boxed(socket)))
Self::new(Target::Socket(socket.into()))
}

/// Specify the mechanisms to use during authentication.
Expand Down Expand Up @@ -343,17 +343,11 @@ impl<'a> Builder<'a> {
}

async fn build_(mut self, executor: Executor<'static>) -> Result<Connection> {
let mut stream = self.stream_for_target().await?;
let (mut stream, server_guid) = self.target_connect().await?;
Comment thread
elmarco marked this conversation as resolved.
let mut auth = match self.guid {
None => {
let guid = match self.target {
Some(Target::Address(ref addr)) => {
addr.guid().map(|guid| guid.to_owned().into())
}
_ => None,
};
// SASL Handshake
Authenticated::client(stream, guid, self.auth_mechanisms).await?
Authenticated::client(stream, server_guid, self.auth_mechanisms).await?
}
Some(guid) => {
if !self.p2p {
Expand Down Expand Up @@ -456,36 +450,42 @@ impl<'a> Builder<'a> {
}
}

async fn stream_for_target(&mut self) -> Result<BoxedSplit> {
// SAFETY: `self.target` is always `Some` from the beginning and this methos is only called
async fn target_connect(&mut self) -> Result<(BoxedSplit, Option<OwnedGuid>)> {
// SAFETY: `self.target` is always `Some` from the beginning and this method is only called
// once.
Ok(match self.target.take().unwrap() {
let split = match self.target.take().unwrap() {
#[cfg(not(feature = "tokio"))]
Target::UnixStream(stream) => Split::new_boxed(Async::new(stream)?),
Target::UnixStream(stream) => Async::new(stream)?.into(),
#[cfg(all(unix, feature = "tokio"))]
Target::UnixStream(stream) => Split::new_boxed(stream),
Target::UnixStream(stream) => stream.into(),
#[cfg(all(not(unix), feature = "tokio"))]
Target::UnixStream(_) => return Err(Error::Unsupported),
#[cfg(not(feature = "tokio"))]
Target::TcpStream(stream) => Split::new_boxed(Async::new(stream)?),
Target::TcpStream(stream) => Async::new(stream)?.into(),
#[cfg(feature = "tokio")]
Target::TcpStream(stream) => Split::new_boxed(stream),
Target::TcpStream(stream) => stream.into(),
#[cfg(all(feature = "vsock", not(feature = "tokio")))]
Target::VsockStream(stream) => Split::new_boxed(Async::new(stream)?),
Target::VsockStream(stream) => Async::new(stream)?.into(),
#[cfg(feature = "tokio-vsock")]
Target::VsockStream(stream) => Split::new_boxed(stream),
Target::Address(address) => match address.connect().await? {
#[cfg(any(unix, not(feature = "tokio")))]
address::transport::Stream::Unix(stream) => Split::new_boxed(stream),
address::transport::Stream::Tcp(stream) => Split::new_boxed(stream),
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
))]
address::transport::Stream::Vsock(stream) => Split::new_boxed(stream),
},
Target::VsockStream(stream) => stream.into(),
Target::Address(address) => {
let guid = address.guid().map(|g| g.to_owned().into());
let split = match address.connect().await? {
#[cfg(any(unix, not(feature = "tokio")))]
address::transport::Stream::Unix(stream) => stream.into(),
address::transport::Stream::Tcp(stream) => stream.into(),
#[cfg(any(
all(feature = "vsock", not(feature = "tokio")),
feature = "tokio-vsock"
))]
address::transport::Stream::Vsock(stream) => stream.into(),
};
return Ok((split, guid));
}
Target::Socket(stream) => stream,
})
};

Ok((split, None))
}
}

Expand Down
16 changes: 8 additions & 8 deletions zbus/src/connection/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,7 @@ mod tests {

use super::*;

use crate::{connection::socket::Split, Guid, Socket};
use crate::{Guid, Socket};

fn create_async_socket_pair() -> (impl AsyncWrite + Socket, impl AsyncWrite + Socket) {
// Tokio needs us to call the sync function from async context. :shrug:
Expand All @@ -1071,9 +1071,9 @@ mod tests {
let (p0, p1) = create_async_socket_pair();

let guid = OwnedGuid::from(Guid::generate());
let client = ClientHandshake::new(Split::new_boxed(p0), None, Some(guid.clone()));
let client = ClientHandshake::new(p0.into(), None, Some(guid.clone()));
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
guid,
Some(Uid::effective().into()),
None,
Expand All @@ -1097,7 +1097,7 @@ mod tests {
fn pipelined_handshake() {
let (mut p0, p1) = create_async_socket_pair();
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
Guid::generate().into(),
Some(Uid::effective().into()),
None,
Expand Down Expand Up @@ -1126,7 +1126,7 @@ mod tests {
fn separate_external_data() {
let (mut p0, p1) = create_async_socket_pair();
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
Guid::generate().into(),
Some(Uid::effective().into()),
None,
Expand All @@ -1153,7 +1153,7 @@ mod tests {
fn missing_external_data() {
let (mut p0, p1) = create_async_socket_pair();
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
Guid::generate().into(),
Some(Uid::effective().into()),
None,
Expand All @@ -1171,7 +1171,7 @@ mod tests {
fn anonymous_handshake() {
let (mut p0, p1) = create_async_socket_pair();
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
Guid::generate().into(),
Some(Uid::effective().into()),
Some(vec![AuthMechanism::Anonymous].into()),
Expand All @@ -1189,7 +1189,7 @@ mod tests {
fn separate_anonymous_data() {
let (mut p0, p1) = create_async_socket_pair();
let server = ServerHandshake::new(
Split::new_boxed(p1),
p1.into(),
Guid::generate().into(),
Some(Uid::effective().into()),
Some(vec![AuthMechanism::Anonymous].into()),
Expand Down
21 changes: 11 additions & 10 deletions zbus/src/connection/socket/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,6 @@ pub struct Split<R: ReadHalf, W: WriteHalf> {
}

impl<R: ReadHalf, W: WriteHalf> Split<R, W> {
/// Create a new boxed `Split` from `socket`.
pub fn new_boxed<S: Socket<ReadHalf = R, WriteHalf = W>>(socket: S) -> BoxedSplit {
let split = socket.split();

Split {
read: Box::new(split.read),
write: Box::new(split.write),
}
}

/// Reference to the read half.
pub fn read(&self) -> &R {
&self.read
Expand Down Expand Up @@ -46,3 +36,14 @@ impl<R: ReadHalf, W: WriteHalf> Split<R, W> {

/// A boxed `Split`.
pub type BoxedSplit = Split<Box<dyn ReadHalf>, Box<dyn WriteHalf>>;

impl<S: Socket> From<S> for BoxedSplit {
fn from(socket: S) -> Self {
let split = socket.split();

Split {
read: Box::new(split.read),
write: Box::new(split.write),
}
}
}
57 changes: 36 additions & 21 deletions zbus/src/connection/socket/tcp.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
#[cfg(not(feature = "tokio"))]
use crate::fdo::ConnectionCredentials;
#[cfg(not(feature = "tokio"))]
use async_io::Async;
use std::io;
#[cfg(unix)]
Expand Down Expand Up @@ -28,7 +26,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
}
}

async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
#[cfg(windows)]
let creds = {
let stream = self.clone();
Expand All @@ -40,7 +38,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
let sid = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
.and_then(|process_token| process_token.sid())?;
io::Result::Ok(
ConnectionCredentials::default()
crate::fdo::ConnectionCredentials::default()
.set_process_id(pid)
.set_windows_sid(sid),
)
Expand All @@ -51,7 +49,7 @@ impl ReadHalf for Arc<Async<TcpStream>> {
}?;

#[cfg(not(windows))]
let creds = ConnectionCredentials::default();
let creds = crate::fdo::ConnectionCredentials::default();

Ok(creds)
}
Expand Down Expand Up @@ -85,7 +83,7 @@ impl WriteHalf for Arc<Async<TcpStream>> {
.await
}

async fn peer_credentials(&mut self) -> io::Result<ConnectionCredentials> {
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
ReadHalf::peer_credentials(self).await
}
}
Expand Down Expand Up @@ -119,21 +117,13 @@ impl ReadHalf for tokio::net::tcp::OwnedReadHalf {
}

#[cfg(windows)]
fn peer_sid(&self) -> Option<String> {
use crate::win32::{socket_addr_get_pid, ProcessToken};

let peer_addr = match self.peer_addr() {
Ok(addr) => addr,
Err(_) => return None,
};

if let Ok(pid) = socket_addr_get_pid(&peer_addr) {
if let Ok(process_token) = ProcessToken::open(if pid != 0 { Some(pid) } else { None }) {
return process_token.sid().ok();
}
}

None
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
let peer_addr = self.peer_addr()?.clone();
crate::Task::spawn_blocking(
move || win32_credentials_from_addr(&peer_addr),
"peer credentials",
)
.await
}
}

Expand Down Expand Up @@ -161,4 +151,29 @@ impl WriteHalf for tokio::net::tcp::OwnedWriteHalf {
async fn close(&mut self) -> io::Result<()> {
tokio::io::AsyncWriteExt::shutdown(self).await
}

#[cfg(windows)]
async fn peer_credentials(&mut self) -> io::Result<crate::fdo::ConnectionCredentials> {
let peer_addr = self.peer_addr()?.clone();
crate::Task::spawn_blocking(
move || win32_credentials_from_addr(&peer_addr),
"peer credentials",
)
.await
}
}

#[cfg(feature = "tokio")]
#[cfg(windows)]
fn win32_credentials_from_addr(
addr: &std::net::SocketAddr,
) -> io::Result<crate::fdo::ConnectionCredentials> {
use crate::win32::{socket_addr_get_pid, ProcessToken};

let pid = socket_addr_get_pid(addr)? as _;
let sid = ProcessToken::open(if pid != 0 { Some(pid as _) } else { None })
.and_then(|process_token| process_token.sid())?;
Ok(crate::fdo::ConnectionCredentials::default()
.set_process_id(pid)
.set_windows_sid(sid))
}
Loading