Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 104 additions & 9 deletions rust/src/mqtt/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::applayer::*;
use crate::applayer::{self, LoggerFlags};
use crate::conf::conf_get;
use crate::core::*;
use crate::frames::*;
use nom7::Err;
use std;
use std::collections::VecDeque;
Expand All @@ -41,6 +42,14 @@ static mut MQTT_MAX_TX: usize = 1024;

static mut ALPROTO_MQTT: AppProto = ALPROTO_UNKNOWN;

#[derive(AppLayerFrameType)]
pub enum MQTTFrameType {
Pdu,
Header,
Data,
TruncData,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following the discussion that happened privately, this data frame type does not seem to be necessary as we cannot annotate something that does not exist. Even if the data is truncated, the parsed parts can be tagged with Pdu, so, we can get rid of this. Make sure to add tests corresponding to truncated data. e.g.

aaaaaaaaaaaaaaabbbbbbbbbb 

where a..a == data exactly the length limit
b..b == extra data that should ideally be truncated by the parser.
So, you have three tests.

  1. Well within the boundary.
    check for say first 4B of data in Pdu, it should match.
  2. Boundary test. Should match.
    check for all aaaa exactly up to the length limit.
  3. Post boundary test. Should not match.
    check for aaaa...bbb i.e. good data + truncated data

}

#[derive(FromPrimitive, Debug, AppLayerEvent)]
pub enum MQTTEvent {
MissingConnect,
Expand Down Expand Up @@ -422,8 +431,10 @@ impl MQTTState {
}
}

fn parse_request(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_request(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -455,6 +466,13 @@ impl MQTTState {
SCLogDebug!("request: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
current.len() as i64,
MQTTFrameType::Pdu as u8,
);
SCLogDebug!("request msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -463,17 +481,21 @@ impl MQTTState {
current.len()
);
if trunc.skipped_length >= current.len() {
self.mqtt_hdr_and_data_frames_trunc(flow, &stream_slice, trunc, current, &msg);
self.skip_request = trunc.skipped_length - current.len();
self.handle_msg(msg, true);
return AppLayerResult::ok();
} else {
self.mqtt_hdr_and_data_frames_trunc(flow, &stream_slice, trunc, current, &msg);
consumed += trunc.skipped_length;
current = &current[trunc.skipped_length..];
self.handle_msg(msg, true);
self.skip_request = 0;
continue;
}
}

self.mqtt_hdr_and_data_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, false);
consumed += current.len() - rem.len();
current = rem;
Expand All @@ -497,8 +519,10 @@ impl MQTTState {
return AppLayerResult::ok();
}

fn parse_response(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_response(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -529,6 +553,14 @@ impl MQTTState {
SCLogDebug!("response: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
input.len() as i64,
MQTTFrameType::Pdu as u8,
);

SCLogDebug!("response msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -537,18 +569,22 @@ impl MQTTState {
current.len()
);
if trunc.skipped_length >= current.len() {
self.mqtt_hdr_and_data_frames_trunc(flow, &stream_slice, trunc, current, &msg);
self.skip_response = trunc.skipped_length - current.len();
self.handle_msg(msg, true);
SCLogDebug!("skip_response now {}", self.skip_response);
return AppLayerResult::ok();
} else {
self.mqtt_hdr_and_data_frames_trunc(flow, &stream_slice, trunc, current, &msg);
consumed += trunc.skipped_length;
current = &current[trunc.skipped_length..];
self.handle_msg(msg, true);
self.skip_response = 0;
continue;
}
}

self.mqtt_hdr_and_data_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, true);
consumed += current.len() - rem.len();
current = rem;
Expand Down Expand Up @@ -589,6 +625,65 @@ impl MQTTState {
tx.tx_data.set_event(event as u8);
self.transactions.push_back(tx);
}

fn mqtt_hdr_and_data_frames(&mut self, flow: *const Flow, stream_slice: &StreamSlice, input: &MQTTMessage) {
let hdr = stream_slice.as_slice();
//MQTT payload has a fixed header of 2 bytes
let _mqtt_hdr = Frame::new(flow, stream_slice, hdr, 2, MQTTFrameType::Header as u8);
SCLogDebug!("mqtt_hdr Frame {:?}", _mqtt_hdr);
let rem_length = input.header.remaining_length as usize;
let data = &hdr[2..rem_length + 2];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::Data as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
}

fn mqtt_hdr_and_data_frames_trunc(
&mut self, flow: *const Flow, stream_slice: &StreamSlice, input: &MQTTTruncatedData,
current: &[u8], msg: &MQTTMessage,
) {
let hdr = stream_slice.as_slice();
let hdr_length = input.skipped_length - msg.header.remaining_length as usize;
let _mqtt_hdr = Frame::new(
flow,
stream_slice,
hdr,
hdr_length as i64,
MQTTFrameType::Header as u8,
);
SCLogDebug!("mqtt_hdr Frame {:?}", _mqtt_hdr);

if input.skipped_length >= current.len() {
//taking current.len() as reference as trunc.skipped_length >= current.len()

let rem_length = current.len() - hdr_length;
let data = &hdr[hdr_length..current.len()];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::TruncData as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
} else {
let rem_length = input.skipped_length - hdr_length;
let data = &hdr[hdr_length..rem_length + hdr_length];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::TruncData as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
}
}
}

// C exports.
Expand Down Expand Up @@ -637,20 +732,20 @@ pub unsafe extern "C" fn rs_mqtt_state_tx_free(state: *mut std::os::raw::c_void,

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_request(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_request(stream_slice.as_slice());
return state.parse_request(flow, stream_slice);
}

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_response(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_response(stream_slice.as_slice());
return state.parse_response(flow, stream_slice);
}

#[no_mangle]
Expand Down Expand Up @@ -761,8 +856,8 @@ pub unsafe extern "C" fn rs_mqtt_register_parser(cfg_max_msg_len: u32) {
apply_tx_config: None,
flags: APP_LAYER_PARSER_OPT_UNIDIR_TXS,
truncate: None,
get_frame_id_by_name: None,
get_frame_name_by_id: None,
get_frame_id_by_name: Some(MQTTFrameType::ffi_id_from_name),
get_frame_name_by_id: Some(MQTTFrameType::ffi_name_from_id),
};

let ip_proto_str = CString::new("tcp").unwrap();
Expand All @@ -783,4 +878,4 @@ pub unsafe extern "C" fn rs_mqtt_register_parser(cfg_max_msg_len: u32) {
} else {
SCLogDebug!("Protocol detector and parser disabled for MQTT.");
}
}
}