diff --git a/include/tins/icmp_extension.h b/include/tins/icmp_extension.h index e4a5c75..9731d79 100644 --- a/include/tins/icmp_extension.h +++ b/include/tins/icmp_extension.h @@ -2,7 +2,10 @@ #define TINS_ICMP_EXTENSION_H #include +#include #include +#include "small_uint.h" +#include "endianness.h" namespace Tins { @@ -11,7 +14,15 @@ namespace Tins { */ class ICMPExtension { public: + /** + * The type used to store the payload + */ typedef std::vector payload_type; + + /** + * The type that will be returned when serializing an extensions + * structure object + */ typedef std::vector serialization_type; /** @@ -50,7 +61,7 @@ public: * * \return The size of this extension */ - uint32_t extension_size() const; + uint32_t size() const; /** * \brief Serializes this extension into a buffer @@ -61,9 +72,9 @@ public: void serialize(uint8_t* buffer, uint32_t buffer_size) const; /** - * \brief Serializes this ICMP extension object + * \brief Serializes this extension object * - * \return The serialized ICMP extension + * \return The serialized extension */ serialization_type serialize() const; private: @@ -73,6 +84,119 @@ private: uint8_t extension_class_, extension_type_; }; +/** + * \brief Class that represents an ICMP extensions structure + */ +class ICMPExtensionsStructure { +public: + /** + * The type that will be returned when serializing an extensions + * structure object + */ + typedef ICMPExtension::serialization_type serialization_type; + + /** + * The type used to store the list of ICMP extensions in this structure + */ + typedef std::list extensions_type; + + /** + * \brief Default constructor + * + * This sets the version to 2, as specified in RFC 4884 + */ + ICMPExtensionsStructure(); + + /** + * \brief Constructor from a buffer. + * + * This constructor will find, parse and store the extension + * stack in the buffer. + */ + ICMPExtensionsStructure(const uint8_t* buffer, uint32_t total_sz); + + /** + * \brief Setter for the checksum field + * + * \param value The new reserved field value + */ + void reserved(small_uint<12> value); + + /** + * \brief Getter for the version field + * + * \return The version field value + */ + small_uint<4> version() const { + uint16_t value = Endian::be_to_host(version_and_reserved_); + return (value >> 12) & 0xf; + } + + /** + * \brief Getter for the reserved field + * + * \return The reserved field value + */ + small_uint<12> reserved() const { + uint16_t value = Endian::be_to_host(version_and_reserved_); + return value & 0xfff; + } + + /** + * \brief Getter for the checksum field + * + * \return The checksum field value + */ + uint16_t checksum() const { return Endian::be_to_host(checksum_); } + + /** + * \brief Getter for the extensions stored by this structure + * + * \return The extensions stored in this structure + */ + const extensions_type& extensions() const { return extensions_; } + + /** + * \brief Gets the size of this ICMP extensions structure + * + * \return The size of this structure + */ + uint32_t size() const; + + /** + * \brief Serializes this extension structure into a buffer + * + * \param buffer The output buffer in which to store the serialization + * \param buffer_size The size of the output buffer + */ + void serialize(uint8_t* buffer, uint32_t buffer_size); + + /** + * \brief Serializes this extension structure + * + * \return The serialized extension structure + */ + serialization_type serialize(); + + /** + * \brief Validates if the given input contains a valid extension structure + * + * The validation is performed by calculating the checksum of the input + * and comparing to the checksum value in the input buffer. + * + * \param buffer The input buffer + * \param total_sz The size of the input buffer + * \return true iff the buffer contains a valid ICMP extensions structure + */ + static bool validate_extensions(const uint8_t* buffer, uint32_t total_sz); +private: + static const uint32_t BASE_HEADER_SIZE; + + uint16_t version_and_reserved_; + uint16_t checksum_; + extensions_type extensions_; +}; + } // Tins #endif // TINS_ICMP_EXTENSION_H diff --git a/include/tins/utils.h b/include/tins/utils.h index 83877da..dbb84c9 100644 --- a/include/tins/utils.h +++ b/include/tins/utils.h @@ -215,17 +215,30 @@ namespace Tins { */ std::string to_string(PDU::PDUType pduType); - /** \brief Does the 16 bits sum of all 2 bytes elements between start and end. + /** + * \brief Does the 16 bits sum of all 2 bytes elements between start and end. * * This is the checksum used by IP, UDP and TCP. If there's and odd number of - * bytes, the last one is padded and added to the checksum. The checksum is performed - * using network endiannes. + * bytes, the last one is padded and added to the checksum. * \param start The pointer to the start of the buffer. * \param end The pointer to the end of the buffer(excluding the last element). - * \return Returns the checksum between start and end(non inclusive). + * \return Returns the checksum between start and end (non inclusive) + * in network endian */ uint32_t do_checksum(const uint8_t *start, const uint8_t *end); + /** + * \brief Computes the 16 bit sum of the input buffer. + * + * If there's and odd number of bytes in the buffer, the last one is padded and + * added to the checksum. + * \param start The pointer to the start of the buffer. + * \param end The pointer to the end of the buffer(excluding the last element). + * \return Returns the checksum between start and end (non inclusive) + * in network endian + */ + uint16_t sum_range(const uint8_t *start, const uint8_t *end); + /** \brief Performs the pseudo header checksum used in TCP and UDP PDUs. * * \param source_ip The source ip address. diff --git a/src/icmp_extension.cpp b/src/icmp_extension.cpp index 5cfa895..d81d1aa 100644 --- a/src/icmp_extension.cpp +++ b/src/icmp_extension.cpp @@ -1,7 +1,8 @@ #include +#include #include "icmp_extension.h" -#include "endianness.h" #include "exceptions.h" +#include "utils.h" using std::runtime_error; @@ -9,6 +10,8 @@ namespace Tins { const uint32_t ICMPExtension::BASE_HEADER_SIZE = sizeof(uint16_t) + sizeof(uint8_t) * 2; +// ICMPExtension class + ICMPExtension::ICMPExtension(const uint8_t* buffer, uint32_t total_sz) { // Check for the base header (u16 length + u8 clss + u8 type) if (total_sz < BASE_HEADER_SIZE) { @@ -28,15 +31,15 @@ ICMPExtension::ICMPExtension(const uint8_t* buffer, uint32_t total_sz) { payload_.assign(buffer, buffer + length); } -uint32_t ICMPExtension::extension_size() const { +uint32_t ICMPExtension::size() const { return BASE_HEADER_SIZE + payload_.size(); } void ICMPExtension::serialize(uint8_t* buffer, uint32_t buffer_size) const { - if (buffer_size < extension_size()) { + if (buffer_size < size()) { throw runtime_error("Serialization buffer is too small"); } - *(uint16_t*)buffer = Endian::host_to_be(extension_size()); + *(uint16_t*)buffer = Endian::host_to_be(size()); buffer += sizeof(uint16_t); *buffer = extension_class_; buffer += sizeof(uint8_t); @@ -46,7 +49,94 @@ void ICMPExtension::serialize(uint8_t* buffer, uint32_t buffer_size) const { } ICMPExtension::serialization_type ICMPExtension::serialize() const { - serialization_type output(extension_size()); + serialization_type output(size()); + serialize(&output[0], output.size()); + return output; +} + +// ICMPExtensionsStructure class + +const uint32_t ICMPExtensionsStructure::BASE_HEADER_SIZE = sizeof(uint16_t) * 2; + +ICMPExtensionsStructure::ICMPExtensionsStructure() +: version_and_reserved_(0x2000), checksum_(0) { + +} + +ICMPExtensionsStructure::ICMPExtensionsStructure(const uint8_t* buffer, uint32_t total_sz) { + if (total_sz < BASE_HEADER_SIZE) { + throw malformed_packet(); + } + + version_and_reserved_ = *(const uint16_t*)buffer; + buffer += sizeof(uint16_t); + checksum_ = *(const uint16_t*)buffer; + buffer += sizeof(uint16_t); + total_sz -= BASE_HEADER_SIZE; + while (total_sz > 0) { + extensions_.push_back(ICMPExtension(buffer, total_sz)); + uint16_t size = Endian::be_to_host(*(const uint16_t*)buffer); + total_sz -= size; + } +} + +void ICMPExtensionsStructure::reserved(small_uint<12> value) { + uint16_t current_value = version_and_reserved_; + current_value &= 0xf000; + current_value |= value; + version_and_reserved_ = Endian::host_to_be(current_value); +} + +bool ICMPExtensionsStructure::validate_extensions(const uint8_t* buffer, uint32_t total_sz) { + if (total_sz < BASE_HEADER_SIZE) { + return false; + } + uint16_t checksum = *(const uint16_t*)(buffer + sizeof(uint16_t)); + // The buffer is read only, so we can't set the initial checksum to 0. Therefore, + // we sum the first 2 bytes and then the payload + uint32_t actual_checksum = *(const uint16_t*)buffer; + buffer += BASE_HEADER_SIZE; + total_sz -= BASE_HEADER_SIZE; + // Now do the checksum over the payload + actual_checksum += Utils::sum_range(buffer, buffer + total_sz); + return checksum == static_cast(~actual_checksum); +} + +uint32_t ICMPExtensionsStructure::size() const { + typedef extensions_type::const_iterator iterator; + uint32_t output = BASE_HEADER_SIZE; + for (iterator iter = extensions_.begin(); iter != extensions_.end(); ++iter) { + output += iter->size(); + } + return output; +} + +void ICMPExtensionsStructure::serialize(uint8_t* buffer, uint32_t buffer_size) { + const uint32_t structure_size = size(); + if (buffer_size < structure_size) { + throw malformed_packet(); + } + uint8_t* original_ptr = buffer; + memcpy(buffer, &version_and_reserved_, sizeof(version_and_reserved_)); + buffer += sizeof(uint16_t); + // Make checksum 0, for now, we'll compute it at the end + memset(buffer, 0, sizeof(uint16_t)); + buffer += sizeof(uint16_t); + buffer_size -= BASE_HEADER_SIZE; + + typedef extensions_type::const_iterator iterator; + for (iterator iter = extensions_.begin(); iter != extensions_.end(); ++iter) { + iter->serialize(buffer, buffer_size); + buffer += iter->size(); + buffer_size -= iter->size(); + } + uint16_t checksum = ~Utils::sum_range(original_ptr, original_ptr + structure_size); + memcpy(original_ptr + sizeof(uint16_t), &checksum, sizeof(checksum)); + checksum_ = checksum; +} + +ICMPExtensionsStructure::serialization_type ICMPExtensionsStructure::serialize() { + serialization_type output(size()); serialize(&output[0], output.size()); return output; } diff --git a/src/utils.cpp b/src/utils.cpp index 5d26c65..274bb91 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -221,6 +221,10 @@ std::string to_string(PDU::PDUType pduType) { } uint32_t do_checksum(const uint8_t *start, const uint8_t *end) { + return Endian::host_to_be(sum_range(start, end)); +} + +uint16_t sum_range(const uint8_t *start, const uint8_t *end) { uint32_t checksum(0); const uint8_t *last = end; uint16_t buffer = 0; @@ -229,16 +233,20 @@ uint32_t do_checksum(const uint8_t *start, const uint8_t *end) { if(((end - start) & 1) == 1) { last = end - 1; - padding = *(end - 1) << 8; + padding = *(end - 1); } while(ptr < last) { memcpy(&buffer, ptr, sizeof(uint16_t)); - checksum += Endian::host_to_be(buffer); + checksum += buffer; ptr += sizeof(uint16_t); } - return checksum + padding; + checksum += padding; + while (checksum >> 16) { + checksum = (checksum & 0xffff) + (checksum >> 16); + } + return checksum; } uint32_t pseudoheader_checksum(IPv4Address source_ip, IPv4Address dest_ip, uint32_t len, uint32_t flag) { diff --git a/tests/src/icmp_extension.cpp b/tests/src/icmp_extension.cpp index da118bf..b7b4e80 100644 --- a/tests/src/icmp_extension.cpp +++ b/tests/src/icmp_extension.cpp @@ -2,6 +2,7 @@ #include "icmp_extension.h" using Tins::ICMPExtension; +using Tins::ICMPExtensionsStructure; class ICMPExtensionTest : public testing::Test { public: @@ -25,3 +26,40 @@ TEST_F(ICMPExtensionTest, ConstructorFromBuffer) { buffer ); } + +TEST_F(ICMPExtensionTest, ExtensionStructureValidation) { + const uint8_t input[] = { 32, 0, 197, 95, 0, 8, 1, 1, 24, 150, 1, 1 }; + EXPECT_TRUE(ICMPExtensionsStructure::validate_extensions(input, sizeof(input))); +} + +TEST_F(ICMPExtensionTest, ExtensionStructureFromBuffer) { + const uint8_t input[] = { 32, 0, 197, 95, 0, 8, 1, 1, 24, 150, 1, 1 }; + ICMPExtensionsStructure structure(input, sizeof(input)); + EXPECT_EQ(2, structure.version()); + EXPECT_EQ(0, structure.reserved()); + EXPECT_EQ(0xc55f, structure.checksum()); + const ICMPExtensionsStructure::extensions_type& extensions = structure.extensions(); + EXPECT_EQ(1, extensions.size()); + const ICMPExtension& ext = *extensions.begin(); + + const uint8_t payload[] = { 24, 150, 1, 1 }; + EXPECT_EQ(1, ext.extension_class()); + EXPECT_EQ(1, ext.extension_type()); + EXPECT_EQ( + ICMPExtension::payload_type(payload, payload + sizeof(payload)), + ext.payload() + ); + + ICMPExtension::serialization_type buffer = structure.serialize(); + EXPECT_EQ( + ICMPExtension::serialization_type(input, input + sizeof(input)), + buffer + ); +} + +TEST_F(ICMPExtensionTest, Reserved) { + ICMPExtensionsStructure structure; + structure.reserved(0xdea); + EXPECT_EQ(0xdea, structure.reserved()); + EXPECT_EQ(2, structure.version()); +}