Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 101 additions & 8 deletions rust/src/mqtt/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use crate::applayer::*;
use crate::applayer::{self, LoggerFlags};
use crate::conf::conf_get;
use crate::core::*;
use crate::frames::*;
use nom7::Err;
use std;
use std::collections::VecDeque;
Expand All @@ -41,6 +42,13 @@ static mut MQTT_MAX_TX: usize = 1024;

static mut ALPROTO_MQTT: AppProto = ALPROTO_UNKNOWN;

#[derive(AppLayerFrameType)]
pub enum MQTTFrameType {
Pdu,
Header,
Data,
}

#[derive(FromPrimitive, Debug, AppLayerEvent)]
pub enum MQTTEvent {
MissingConnect,
Expand Down Expand Up @@ -422,8 +430,10 @@ impl MQTTState {
}
}

fn parse_request(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_request(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -455,6 +465,13 @@ impl MQTTState {
SCLogDebug!("request: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
current.len() as i64,
MQTTFrameType::Pdu as u8,
);
SCLogDebug!("request msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -463,17 +480,21 @@ impl MQTTState {
current.len()
);
if trunc.skipped_length >= current.len() {
self.mqtt_frames_trunc(flow, &stream_slice, &trunc, &current, &msg);
self.skip_request = trunc.skipped_length - current.len();
self.handle_msg(msg, true);
return AppLayerResult::ok();
} else {
self.mqtt_frames_trunc(flow, &stream_slice, &trunc, &current, &msg);
consumed += trunc.skipped_length;
current = &current[trunc.skipped_length..];
self.handle_msg(msg, true);
self.skip_request = 0;
continue;
}
}

self.mqtt_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, false);
consumed += current.len() - rem.len();
current = rem;
Expand All @@ -497,8 +518,10 @@ impl MQTTState {
return AppLayerResult::ok();
}

fn parse_response(&mut self, input: &[u8]) -> AppLayerResult {
fn parse_response(&mut self, flow: *const Flow, stream_slice: StreamSlice) -> AppLayerResult {
let input = stream_slice.as_slice();
let mut current = input;

if input.is_empty() {
return AppLayerResult::ok();
}
Expand Down Expand Up @@ -529,6 +552,14 @@ impl MQTTState {
SCLogDebug!("response: handling {}", current.len());
match parse_message(current, self.protocol_version, self.max_msg_len) {
Ok((rem, msg)) => {
let _pdu = Frame::new(
flow,
&stream_slice,
input,
input.len() as i64,
MQTTFrameType::Pdu as u8,
);

SCLogDebug!("response msg {:?}", msg);
if let MQTTOperation::TRUNCATED(ref trunc) = msg.op {
SCLogDebug!(
Expand All @@ -537,18 +568,22 @@ impl MQTTState {
current.len()
);
if trunc.skipped_length >= current.len() {
self.mqtt_frames_trunc(flow, &stream_slice, &trunc, &current, &msg);
self.skip_response = trunc.skipped_length - current.len();
self.handle_msg(msg, true);
SCLogDebug!("skip_response now {}", self.skip_response);
return AppLayerResult::ok();
} else {
self.mqtt_frames_trunc(flow, &stream_slice, &trunc, &current, &msg);
consumed += trunc.skipped_length;
current = &current[trunc.skipped_length..];
self.handle_msg(msg, true);
self.skip_response = 0;
continue;
}
}

self.mqtt_frames(flow, &stream_slice, &msg);
self.handle_msg(msg, true);
consumed += current.len() - rem.len();
current = rem;
Expand Down Expand Up @@ -589,6 +624,64 @@ impl MQTTState {
tx.tx_data.set_event(event as u8);
self.transactions.push_back(tx);
}

fn mqtt_frames(&mut self, flow: *const Flow, stream_slice: &StreamSlice, input: &MQTTMessage) {
let hdr = stream_slice.as_slice();
//MQTT payload has a fixed header of 2 bytes
let _mqtt_hdr = Frame::new(flow, stream_slice, hdr, 2, MQTTFrameType::Header as u8);
SCLogDebug!("mqtt_hdr Frame {:?}", _mqtt_hdr);
let rem_length = input.header.remaining_length as usize;
let data = &hdr[2..rem_length + 2];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::Data as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
}
fn mqtt_frames_trunc(
&mut self, flow: *const Flow, stream_slice: &StreamSlice, input: &MQTTTruncatedData,
current: &[u8], msg: &MQTTMessage,
) {
let hdr = stream_slice.as_slice();
let hdr_length = input.skipped_length - msg.header.remaining_length as usize;
let _mqtt_hdr = Frame::new(
flow,
stream_slice,
hdr,
hdr_length as i64,
MQTTFrameType::Header as u8,
);
SCLogDebug!("mqtt_hdr Frame {:?}", _mqtt_hdr);

if input.skipped_length >= current.len() {
//taking current.len() as reference as trunc.skipped_length >= current.len()

let rem_length = current.len() - hdr_length as usize;
let data = &hdr[hdr_length..current.len()];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::Data as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
} else {
let rem_length = input.skipped_length - hdr_length as usize;
let data = &hdr[hdr_length..rem_length + hdr_length];
let _mqtt_data = Frame::new(
flow,
stream_slice,
data,
rem_length as i64,
MQTTFrameType::Data as u8,
);
SCLogDebug!("mqtt_data Frame {:?}", _mqtt_data);
}
}
}

// C exports.
Expand Down Expand Up @@ -637,20 +730,20 @@ pub unsafe extern "C" fn rs_mqtt_state_tx_free(state: *mut std::os::raw::c_void,

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_request(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_request(stream_slice.as_slice());
return state.parse_request(flow, stream_slice);
}

#[no_mangle]
pub unsafe extern "C" fn rs_mqtt_parse_response(
_flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
flow: *const Flow, state: *mut std::os::raw::c_void, _pstate: *mut std::os::raw::c_void,
stream_slice: StreamSlice, _data: *const std::os::raw::c_void,
) -> AppLayerResult {
let state = cast_pointer!(state, MQTTState);
return state.parse_response(stream_slice.as_slice());
return state.parse_response(flow, stream_slice);
}

#[no_mangle]
Expand Down Expand Up @@ -761,8 +854,8 @@ pub unsafe extern "C" fn rs_mqtt_register_parser(cfg_max_msg_len: u32) {
apply_tx_config: None,
flags: APP_LAYER_PARSER_OPT_UNIDIR_TXS,
truncate: None,
get_frame_id_by_name: None,
get_frame_name_by_id: None,
get_frame_id_by_name: Some(MQTTFrameType::ffi_id_from_name),
get_frame_name_by_id: Some(MQTTFrameType::ffi_name_from_id),
};

let ip_proto_str = CString::new("tcp").unwrap();
Expand Down