diff --git a/mqtt/src/connect.rs b/mqtt/src/connect.rs index f8b6ebc0..def1ba88 100644 --- a/mqtt/src/connect.rs +++ b/mqtt/src/connect.rs @@ -86,21 +86,54 @@ impl TryFrom for ConnectReasonCode { } } +#[repr(u8)] +enum ConnectionFlags { + CleanStart = 0x02, + Password = 0x40, + Username = 0x80, +} + +#[derive(Debug)] +pub struct LoginCredentials<'a> { + username: &'a str, + password: &'a str, +} + +impl<'a> LoginCredentials<'a> { + pub fn new(username: &'a str, password: &'a str) -> Self { + Self { username, password } + } +} + pub fn send_connect>( mut writer: Writer, client_id: &Option, + login_credentials: &Option, rx_max: u16, ) -> Result<(), HlError> { const KEEP_ALIVE: u16 = 15 * 60; - let client_id_len: u8 = client_id.map(|id| id.len()).unwrap_or(0); + let mut flags: u8 = ConnectionFlags::CleanStart as u8; + + let client_id = client_id.as_ref(); + let client_id_len = client_id.map_or(0, ClientId::len); + + let mut suffix_len: u8 = client_id_len; + + if let Some(login_credentials) = login_credentials { + flags += ConnectionFlags::Username as u8; + flags += ConnectionFlags::Password as u8; + + suffix_len += payload_len(login_credentials.username); + suffix_len += payload_len(login_credentials.password); + } #[rustfmt::skip] writer.write_all(&[ // control packet type (CtrlPkt::CONNECT as u8) << 4, // remaining length - 18 + client_id_len, + 18 + suffix_len, // protocol name length 0, 4, // protocol name @@ -108,12 +141,12 @@ pub fn send_connect>( // protocol version 5, // flags, clean start is set - 0b00000010, + flags, // keepalive (KEEP_ALIVE >> 8) as u8, KEEP_ALIVE as u8, // properties length 5, - // recieve maximum property + // receive maximum property (Properties::MaxPktSize as u8), 0, 0, (rx_max >> 8) as u8, rx_max as u8, // client ID length 0, client_id_len, @@ -121,6 +154,28 @@ pub fn send_connect>( if let Some(client_id) = client_id { writer.write_all(client_id.as_bytes())?; } + + if let Some(login_credentials) = login_credentials { + let LoginCredentials { username, password } = login_credentials; + writer.write_all(str_len_msb_lsb(username).as_slice())?; + writer.write_all(username.as_bytes())?; + writer.write_all(str_len_msb_lsb(password).as_slice())?; + writer.write_all(password.as_bytes())?; + } + writer.send()?; Ok(()) } + +fn payload_len(s: &str) -> u8 { + // str len + 2 bytes for str len prefix + (s.len() + 2) as u8 +} + +fn str_len_msb_lsb(s: &str) -> [u8; 2] { + let len: u16 = s.len() as u16; + let msb: u8 = (len >> 8) as u8; + let lsb: u8 = len as u8; + + [msb, lsb] +} diff --git a/mqtt/src/lib.rs b/mqtt/src/lib.rs index ec885d7c..90056b14 100644 --- a/mqtt/src/lib.rs +++ b/mqtt/src/lib.rs @@ -79,8 +79,8 @@ pub mod tls; pub use w5500_tls; pub use client_id::ClientId; -use connect::send_connect; pub use connect::ConnectReasonCode; +use connect::{send_connect, LoginCredentials}; use hl::{ io::{Read, Seek, Write}, ll::{net::SocketAddrV4, Registers, Sn, SocketInterrupt, SocketInterruptMask}, @@ -325,6 +325,8 @@ pub struct Client<'a> { state_timeout: StateTimeout, /// Packet ID for subscribing pkt_id: u16, + /// Login credentials + credentials: Option>, } impl<'a> Client<'a> { @@ -364,6 +366,7 @@ impl<'a> Client<'a> { }, client_id: None, pkt_id: 1, + credentials: None, } } @@ -399,6 +402,11 @@ impl<'a> Client<'a> { self.client_id = Some(client_id) } + /// Set the MQTT login credentials. + pub fn set_credentials(&mut self, username: &'a str, password: &'a str) { + self.credentials = Some(LoginCredentials::new(username, password)); + } + fn next_pkt_id(&mut self) -> u16 { self.pkt_id = self.pkt_id.checked_add(1).unwrap_or(1); self.pkt_id @@ -687,7 +695,8 @@ impl<'a> Client<'a> { .size_in_bytes() as u16; let writer: TcpWriter = w5500.tcp_writer(self.sn)?; - send_connect(writer, &self.client_id, rx_max).map_err(Error::map_w5500)?; + send_connect(writer, &self.client_id, &self.credentials, rx_max) + .map_err(Error::map_w5500)?; Ok(self .state_timeout .set_state_with_timeout(State::WaitConAck, monotonic_secs)) diff --git a/mqtt/src/tls.rs b/mqtt/src/tls.rs index 7ea889ad..1b2a72d3 100644 --- a/mqtt/src/tls.rs +++ b/mqtt/src/tls.rs @@ -54,7 +54,7 @@ //! [`w5500-tls`]: https://github.com/newAM/w5500-rs/blob/main/tls/README.md use crate::{ - connect::send_connect, + connect::{send_connect, LoginCredentials}, hl::{ ll::{net::SocketAddrV4, Registers, Sn}, Error as HlError, Hostname, @@ -88,16 +88,18 @@ fn map_tls_writer_err(e: w5500_tls::Error) -> Error { /// /// The methods are nearly identical to [`crate::Client`], see [`crate::Client`] /// for additional documentation and examples. -pub struct Client<'id, 'hn, 'psk, 'b, const N: usize> { +pub struct Client<'id, 'hn, 'psk, 'b, 'cred, const N: usize> { tls: TlsClient<'hn, 'psk, 'b, N>, client_id: Option>, /// State and Timeout tracker state_timeout: StateTimeout, /// Packet ID for subscribing pkt_id: u16, + /// Login credentials + credentials: Option>, } -impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> { +impl<'id, 'hn, 'psk, 'b, 'cred, const N: usize> Client<'id, 'hn, 'psk, 'b, 'cred, N> { /// Create a new MQTT client. /// /// # Arguments @@ -151,8 +153,8 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> { timeout: None, }, client_id: None, - pkt_id: 1, + credentials: None, } } @@ -161,6 +163,11 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> { self.client_id = Some(client_id) } + /// Set the MQTT login credentials. + pub fn set_credentials(&mut self, username: &'cred str, password: &'cred str) { + self.credentials = Some(LoginCredentials::new(username, password)); + } + fn next_pkt_id(&mut self) -> u16 { self.pkt_id = self.pkt_id.checked_add(1).unwrap_or(1); self.pkt_id @@ -280,7 +287,8 @@ impl<'id, 'hn, 'psk, 'b, const N: usize> Client<'id, 'hn, 'psk, 'b, N> { let rx_max: u16 = (N as u16) - TLS_OVERHEAD; let writer: TlsWriter = self.tls.writer(w5500).map_err(map_tls_writer_err)?; - send_connect(writer, &self.client_id, rx_max).map_err(Error::map_w5500)?; + send_connect(writer, &self.client_id, &self.credentials, rx_max) + .map_err(Error::map_w5500)?; Ok(self .state_timeout .set_state_with_timeout(State::WaitConAck, monotonic_secs)) diff --git a/mqtt/tests/connect.rs b/mqtt/tests/connect.rs index 0c74b305..5045ff45 100644 --- a/mqtt/tests/connect.rs +++ b/mqtt/tests/connect.rs @@ -49,6 +49,47 @@ fn connect_no_client_id() { })); } +#[test] +fn connect_with_login() { + const PORT: u16 = 12345; + let mut client: Client = + Client::new(Sn0, SRC_PORT, SocketAddrV4::new(Ipv4Addr::LOCALHOST, PORT)); + client.set_credentials("mqtt-user", "password"); + + let mut fixture = Fixture::with_client(client, PORT); + assert!(matches!( + fixture.client_process().unwrap(), + Event::CallAfter(10) + )); + fixture.server.accept(); + assert!(matches!( + fixture.client_process().unwrap(), + Event::CallAfter(10) + )); + fixture.server_expect(Packet::Connect(Connect { + protocol: V5, + keep_alive: 900, + client_id: "".to_string(), + clean_session: true, + last_will: None, + login: Some(mqttbytes::v5::Login { + username: "mqtt-user".to_string(), + password: "password".to_string(), + }), + properties: Some(ConnectProperties { + session_expiry_interval: None, + receive_maximum: None, + max_packet_size: Some(2048), + topic_alias_max: None, + request_response_info: None, + request_problem_info: None, + user_properties: vec![], + authentication_method: None, + authentication_data: None, + }), + })); +} + #[test] fn connect_fail() { const PORT: u16 = 12344;