diff --git a/src/eapol.cpp b/src/eapol.cpp index 542e29c..22a453c 100644 --- a/src/eapol.cpp +++ b/src/eapol.cpp @@ -33,6 +33,7 @@ #include "eapol.h" #include "rsn_information.h" #include "exceptions.h" +#include "rawpdu.h" #include "memory_helpers.h" using Tins::Memory::InputMemoryStream; @@ -95,6 +96,7 @@ void EAPOL::type(uint8_t new_type) { void EAPOL::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *) { OutputMemoryStream stream(buffer, total_sz); + length(total_sz - 4); stream.write(_header); std::memcpy(buffer, &_header, sizeof(_header)); write_body(stream); @@ -111,15 +113,16 @@ RC4EAPOL::RC4EAPOL() RC4EAPOL::RC4EAPOL(const uint8_t *buffer, uint32_t total_sz) : EAPOL(buffer, total_sz) { - buffer += sizeof(eapolhdr); - total_sz -= sizeof(eapolhdr); - if(total_sz < sizeof(_header)) - throw malformed_packet(); - std::memcpy(&_header, buffer, sizeof(_header)); - buffer += sizeof(_header); - total_sz -= sizeof(_header); - if(total_sz == key_length()) - _key.assign(buffer, buffer + total_sz); + InputMemoryStream stream(buffer, total_sz); + stream.skip(sizeof(eapolhdr)); + stream.read(_header); + if (stream.size() >= key_length()) { + _key.assign(stream.pointer(), stream.pointer() + key_length()); + stream.skip(key_length()); + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); + } + } } void RC4EAPOL::key_length(uint16_t new_key_length) { @@ -174,15 +177,16 @@ RSNEAPOL::RSNEAPOL() RSNEAPOL::RSNEAPOL(const uint8_t *buffer, uint32_t total_sz) : EAPOL(buffer, total_sz) { - buffer += sizeof(eapolhdr); - total_sz -= sizeof(eapolhdr); - if(total_sz < sizeof(_header)) - throw malformed_packet(); - std::memcpy(&_header, buffer, sizeof(_header)); - buffer += sizeof(_header); - total_sz -= sizeof(_header); - if(total_sz == wpa_length()) - _key.assign(buffer, buffer + total_sz); + InputMemoryStream stream(buffer, total_sz); + stream.skip(sizeof(eapolhdr)); + stream.read(_header); + if (stream.size() >= wpa_length()) { + _key.assign(stream.pointer(), stream.pointer() + wpa_length()); + stream.skip(wpa_length()); + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); + } + } } void RSNEAPOL::nonce(const uint8_t *new_nonce) { @@ -268,7 +272,7 @@ uint32_t RSNEAPOL::header_size() const { void RSNEAPOL::write_body(OutputMemoryStream& stream) { if (_key.size()) { - if (!_header.key_t) { + if (!_header.key_t && _header.install) { _header.key_length = Endian::host_to_be(32); wpa_length(static_cast(_key.size())); } diff --git a/tests/src/rsn_eapol.cpp b/tests/src/rsn_eapol.cpp index 5468fc1..af94b73 100644 --- a/tests/src/rsn_eapol.cpp +++ b/tests/src/rsn_eapol.cpp @@ -5,6 +5,7 @@ #include "eapol.h" #include "snap.h" #include "utils.h" +#include "ethernetII.h" #include "rsn_information.h" using namespace std; @@ -14,6 +15,7 @@ class RSNEAPOLTest : public testing::Test { public: static const uint8_t expected_packet[]; static const uint8_t eapol_over_snap[]; + static const uint8_t broken_eapol[]; void test_equals(const RSNEAPOL &eapol1, const RSNEAPOL &eapol2); }; @@ -65,6 +67,18 @@ const uint8_t RSNEAPOLTest::eapol_over_snap[] = { 123, 212, 159 }; +const uint8_t RSNEAPOLTest::broken_eapol[] = { + 44, 240, 238, 33, 128, 46, 72, 248, 179, 139, 32, 112, 136, 142, 2, + 3, 0, 127, 2, 19, 130, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 231, 103, 200, 107, 89, 185, 187, 51, 27, 32, 91, 65, 95, + 165, 127, 37, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 126, + 159, 123, 33, 66, 3, 254, 124, 6, 192, 129, 143, 215, 59, 38, 162, + 0, 24, 221, 22, 0, 15, 172, 1, 1, 0, 237, 214, 169, 68, 84, 98, 24, + 182, 8, 221, 81, 125, 222, 224, 243, 97, 229, 99, 186, 225, 196, 225, + 179, 86 +}; + void RSNEAPOLTest::test_equals(const RSNEAPOL &eapol1, const RSNEAPOL &eapol2) { EXPECT_EQ(eapol1.version(), eapol2.version()); EXPECT_EQ(eapol1.packet_type(), eapol2.packet_type()); @@ -143,6 +157,19 @@ TEST_F(RSNEAPOLTest, Serialize) { EXPECT_TRUE(std::equal(buffer.begin(), buffer.end(), expected_packet)); } +// This is a test for a packet for which the serialization lacked the WPA key. +// This packet contains a misterious 8 byte field that I can't seem to find +// on the standard. Wireshark doesn't understand it either. This will currently +// be appended as a RawPDU at the end. +TEST_F(RSNEAPOLTest, SerializeBrokenEapol) { + EthernetII eapol(broken_eapol, sizeof(broken_eapol)); + RSNEAPOL::serialization_type buffer = eapol.serialize(); + EXPECT_EQ( + RSNEAPOL::serialization_type(broken_eapol, broken_eapol + sizeof(broken_eapol)), + buffer + ); +} + TEST_F(RSNEAPOLTest, ConstructionTest) { RSNEAPOL eapol; eapol.version(1);