From 13c05fbdb17c5414237f989f4763677f18c8e019 Mon Sep 17 00:00:00 2001 From: Matias Fontanini Date: Thu, 24 Dec 2015 15:21:07 -0800 Subject: [PATCH] Add input memory stream class and port some PDUs to use it --- include/tins/icmp.h | 5 +- include/tins/icmpv6.h | 8 ++- include/tins/internals.h | 5 +- include/tins/macros.h | 12 +++- include/tins/memory_helpers.h | 101 +++++++++++++++++++++++++++++++ src/arp.cpp | 20 +++---- src/bootp.cpp | 16 ++--- src/dhcp.cpp | 48 ++++++++------- src/dhcpv6.cpp | 61 +++++++++---------- src/dns.cpp | 18 +++--- src/dot1q.cpp | 16 ++--- src/dot3.cpp | 16 ++--- src/eapol.cpp | 14 +++-- src/ethernetII.cpp | 22 +++---- src/icmp.cpp | 41 +++++-------- src/icmp_extension.cpp | 39 +++++------- src/icmpv6.cpp | 69 +++++++++------------- src/internals.cpp | 24 ++++---- src/ip.cpp | 108 +++++++++++++++++++--------------- src/tcp.cpp | 60 +++++++++++-------- 20 files changed, 408 insertions(+), 295 deletions(-) create mode 100644 include/tins/memory_helpers.h diff --git a/include/tins/icmp.h b/include/tins/icmp.h index 33f7c7f..6c0f1ac 100644 --- a/include/tins/icmp.h +++ b/include/tins/icmp.h @@ -49,6 +49,9 @@ #include "icmp_extension.h" namespace Tins { +namespace Memory { +class InputMemoryStream; +} // memory /** * \class ICMP @@ -460,7 +463,7 @@ private: void write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *parent); uint32_t get_adjusted_inner_pdu_size() const; - void try_parse_extensions(const uint8_t* buffer, uint32_t& total_sz); + void try_parse_extensions(Memory::InputMemoryStream& stream); bool are_extensions_allowed() const; icmphdr _icmp; diff --git a/include/tins/icmpv6.h b/include/tins/icmpv6.h index 40e4364..c7d24d8 100644 --- a/include/tins/icmpv6.h +++ b/include/tins/icmpv6.h @@ -44,6 +44,10 @@ #include "cxxstd.h" namespace Tins { +namespace Memory { +class InputMemoryStream; +} // memory + /** * \class ICMPv6 * \brief Represents an ICMPv6 PDU. @@ -1376,12 +1380,12 @@ private: void write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *parent); bool has_options() const; uint8_t *write_option(const option &opt, uint8_t *buffer); - void parse_options(const uint8_t *&buffer, uint32_t &total_sz); + void parse_options(Memory::InputMemoryStream& stream); void add_addr_list(uint8_t type, const addr_list_type &value); addr_list_type search_addr_list(OptionTypes type) const; options_type::const_iterator search_option_iterator(OptionTypes type) const; options_type::iterator search_option_iterator(OptionTypes type); - void try_parse_extensions(const uint8_t* buffer, uint32_t& total_sz); + void try_parse_extensions(Memory::InputMemoryStream& stream); bool are_extensions_allowed() const; uint32_t get_adjusted_inner_pdu_size() const; diff --git a/include/tins/internals.h b/include/tins/internals.h index 283fc88..9b71c8b 100644 --- a/include/tins/internals.h +++ b/include/tins/internals.h @@ -45,6 +45,9 @@ * \cond */ namespace Tins { +namespace Memory { +class InputMemoryStream; +} // Memory class IPv4Address; class IPv6Address; class ICMPExtensionsStructure; @@ -126,7 +129,7 @@ Constants::Ethernet::e pdu_flag_to_ether_type(PDU::PDUType flag); Constants::IP::e pdu_flag_to_ip_type(PDU::PDUType flag); uint32_t get_padded_icmp_inner_pdu_size(const PDU* inner_pdu, uint32_t pad_alignment); -void try_parse_icmp_extensions(const uint8_t* buffer, uint32_t& total_sz, +void try_parse_icmp_extensions(Memory::InputMemoryStream& stream, uint32_t payload_length, ICMPExtensionsStructure& extensions); template diff --git a/include/tins/macros.h b/include/tins/macros.h index a22321c..ac330a4 100644 --- a/include/tins/macros.h +++ b/include/tins/macros.h @@ -34,19 +34,25 @@ #include #endif -// Packing directives.... +// Check if this is Visual Studio #ifdef _MSC_VER + // This is Visual Studio #define TINS_BEGIN_PACK __pragma( pack(push, 1) ) #define TINS_END_PACK __pragma( pack(pop) ) #define TINS_PACKED(DECLARATION) __pragma( pack(push, 1) ) DECLARATION __pragma( pack(pop) ) #define TINS_DEPRECATED(func) __declspec(deprecated) func #define TINS_NOEXCEPT + #define TINS_LIKELY(x) (x) + #define TINS_UNLIKELY(x) (x) #else + // Not Vistual Studio. Assume this is gcc compatible #define TINS_BEGIN_PACK #define TINS_END_PACK __attribute__((packed)) #define TINS_PACKED(DECLARATION) DECLARATION __attribute__((packed)) #define TINS_DEPRECATED(func) func __attribute__ ((deprecated)) #define TINS_NOEXCEPT noexcept -#endif + #define TINS_LIKELY(x) __builtin_expect((x),1) + #define TINS_UNLIKELY(x) __builtin_expect((x),0) +#endif // -#endif +#endif // TINS_MACROS_H diff --git a/include/tins/memory_helpers.h b/include/tins/memory_helpers.h new file mode 100644 index 0000000..d705482 --- /dev/null +++ b/include/tins/memory_helpers.h @@ -0,0 +1,101 @@ +#ifndef TINS_MEMORY_HELPERS_H +#define TINS_MEMORY_HELPERS_H + +#include +#include +#include "exceptions.h" +#include "ip_address.h" +#include "ipv6_address.h" + +namespace Tins { +namespace Memory { + +inline void read_data(const uint8_t* buffer, uint8_t* output_buffer, uint32_t size) { + std::memcpy(output_buffer, buffer, size); +} + +template +void read_value(const uint8_t* buffer, T& value) { + std::memcpy(&value, buffer, sizeof(value)); +} + +template +void write_value(uint8_t* buffer, const T& value) { + std::memcpy(buffer, &value, sizeof(value)); +} + +class InputMemoryStream { +public: + InputMemoryStream(const uint8_t* buffer, uint32_t total_sz) + : buffer_(buffer), size_(total_sz) { + } + + void skip(uint32_t size) { + buffer_ += size; + size_ -= size; + } + + bool can_read(uint32_t byte_count) const { + return TINS_LIKELY(size_ >= byte_count); + } + + template + T read() { + T output; + read(output); + return output; + } + + template + void read(T& value) { + if (!can_read(sizeof(value))) { + throw malformed_packet(); + } + read_value(buffer_, value); + skip(sizeof(value)); + } + + void read(IPv4Address& address) { + address = IPv4Address(read()); + } + + void read(IPv6Address& address) { + if (!can_read(IPv6Address::address_size)) { + throw malformed_packet(); + } + address = pointer(); + skip(IPv6Address::address_size); + } + + void read(void* output_buffer, uint32_t output_buffer_size) { + if (!can_read(output_buffer_size)) { + throw malformed_packet(); + } + read_data(buffer_, (uint8_t*)output_buffer, output_buffer_size); + skip(output_buffer_size); + } + + const uint8_t* pointer() const { + return buffer_; + } + + uint32_t size() const { + return size_; + } + + void size(uint32_t new_size) { + size_ = new_size; + } + + operator bool() const { + return size_ > 0; + } +private: + const uint8_t* buffer_; + uint32_t size_; +}; + +} // Memory +} // Tins + +#endif // TINS_MEMORY_HELPERS_H diff --git a/src/arp.cpp b/src/arp.cpp index 970f4b0..fe9f54f 100644 --- a/src/arp.cpp +++ b/src/arp.cpp @@ -37,16 +37,18 @@ #include "constants.h" #include "network_interface.h" #include "exceptions.h" - +#include "memory_helpers.h" using std::runtime_error; +using Tins::Memory::InputMemoryStream; + namespace Tins { ARP::ARP(ipaddress_type target_ip, ipaddress_type sender_ip, - const hwaddress_type &target_hw, const hwaddress_type &sender_hw) +const hwaddress_type &target_hw, const hwaddress_type &sender_hw) +: _arp() { - memset(&_arp, 0, sizeof(arphdr)); hw_addr_format((uint16_t)Constants::ARP::ETHER); prot_addr_format((uint16_t)Constants::Ethernet::IP); hw_addr_length(Tins::EthernetII::address_type::address_size); @@ -59,13 +61,11 @@ ARP::ARP(ipaddress_type target_ip, ipaddress_type sender_ip, ARP::ARP(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(arphdr)) - throw malformed_packet(); - memcpy(&_arp, buffer, sizeof(arphdr)); - total_sz -= sizeof(arphdr); - //TODO: Check whether this should be removed or not. - if(total_sz) - inner_pdu(new RawPDU(buffer + sizeof(arphdr), total_sz)); + InputMemoryStream stream(buffer, total_sz); + stream.read(_arp); + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); + } } void ARP::sender_hw_addr(const hwaddress_type &new_snd_hw_addr) { diff --git a/src/bootp.cpp b/src/bootp.cpp index 6b5510d..10d31c9 100644 --- a/src/bootp.cpp +++ b/src/bootp.cpp @@ -32,23 +32,25 @@ #include #include "bootp.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins{ + BootP::BootP() -: _vend(64) { - std::memset(&_bootp, 0, sizeof(bootphdr)); +: _bootp(), _vend(64) { + } BootP::BootP(const uint8_t *buffer, uint32_t total_sz, uint32_t vend_field_size) : _vend(vend_field_size) { - if(total_sz < sizeof(bootphdr) + vend_field_size) - throw malformed_packet(); - std::memcpy(&_bootp, buffer, sizeof(bootphdr)); + InputMemoryStream stream(buffer, total_sz); + stream.read(_bootp); buffer += sizeof(bootphdr); total_sz -= sizeof(bootphdr); - _vend.assign(buffer, buffer + vend_field_size); - // Maybe RawPDU on what is left on the buffer?... + _vend.assign(stream.pointer(), stream.pointer() + vend_field_size); } uint32_t BootP::header_size() const { diff --git a/src/dhcp.cpp b/src/dhcp.cpp index 8fb7c4e..934cc86 100644 --- a/src/dhcp.cpp +++ b/src/dhcp.cpp @@ -35,15 +35,20 @@ #include "ethernetII.h" #include "internals.h" #include "exceptions.h" +#include "memory_helpers.h" using std::string; using std::list; using std::runtime_error; using std::find_if; +using Tins::Memory::InputMemoryStream; + namespace Tins { + // Magic cookie: uint32_t. -DHCP::DHCP() : _size(sizeof(uint32_t)) { +DHCP::DHCP() +: _size(sizeof(uint32_t)) { opcode(BOOTREQUEST); htype(1); //ethernet hlen(EthernetII::address_type::address_size); @@ -52,33 +57,26 @@ DHCP::DHCP() : _size(sizeof(uint32_t)) { DHCP::DHCP(const uint8_t *buffer, uint32_t total_sz) : BootP(buffer, total_sz, 0), _size(sizeof(uint32_t)) { - buffer += BootP::header_size() - vend().size(); - total_sz -= static_cast(BootP::header_size() - vend().size()); - uint8_t args[2] = {0}; - uint32_t uint32_t_buffer; - std::memcpy(&uint32_t_buffer, buffer, sizeof(uint32_t)); - if(total_sz < sizeof(uint32_t) || uint32_t_buffer != Endian::host_to_be(0x63825363)) + const uint32_t bootp_size = BootP::header_size() - vend().size(); + InputMemoryStream stream(buffer + bootp_size, total_sz - bootp_size); + const uint32_t magic_number = stream.read(); + if (magic_number != Endian::host_to_be(0x63825363)) throw malformed_packet(); - buffer += sizeof(uint32_t); - total_sz -= sizeof(uint32_t); - while(total_sz) { - for(unsigned i(0); i < 2; ++i) { - args[i] = *(buffer++); - total_sz--; - if(args[0] == END || args[0] == PAD) { - args[1] = 0; - i = 2; - } - else if(!total_sz) - throw malformed_packet(); + // While there's data left + while (stream) { + OptionTypes option_type; + uint8_t option_length = 0; + option_type = (OptionTypes)stream.read(); + // We should only read the length if it's not END nor PAD + if (option_type != END && option_type != PAD) { + option_length = stream.read(); } - if(total_sz < args[1]) + // Make sure we can read the payload size + if (!stream.can_read(option_length)) { throw malformed_packet(); - add_option( - option((OptionTypes)args[0], args[1], buffer) - ); - buffer += args[1]; - total_sz -= args[1]; + } + add_option(option(option_type, option_length, stream.pointer())); + stream.skip(option_length); } } diff --git a/src/dhcpv6.cpp b/src/dhcpv6.cpp index 2109dba..9b22276 100644 --- a/src/dhcpv6.cpp +++ b/src/dhcpv6.cpp @@ -31,53 +31,50 @@ #include #include "dhcpv6.h" #include "exceptions.h" +#include "memory_helpers.h" using std::find_if; +using Tins::Memory::InputMemoryStream; + namespace Tins { -DHCPv6::DHCPv6() : options_size() { - std::fill(header_data, header_data + sizeof(header_data), 0); + +DHCPv6::DHCPv6() +: header_data(), options_size() { + } DHCPv6::DHCPv6(const uint8_t *buffer, uint32_t total_sz) : options_size() { - if(total_sz == 0) + InputMemoryStream stream(buffer, total_sz); + if (!stream) { throw malformed_packet(); - // Relay Agent/Server Messages - bool is_relay_msg = (buffer[0] == 12 || buffer[0] == 13); - uint32_t required_size = is_relay_msg ? 2 : 4; - if(total_sz < required_size) - throw malformed_packet(); - std::copy(buffer, buffer + required_size, header_data); - buffer += required_size; - total_sz -= required_size; - if(is_relay_message()) { - if(total_sz < ipaddress_type::address_size * 2) - throw malformed_packet(); - link_addr = buffer; - peer_addr = buffer + ipaddress_type::address_size; - buffer += ipaddress_type::address_size * 2; - total_sz -= ipaddress_type::address_size * 2; } - while(total_sz) { - if(total_sz < sizeof(uint16_t) * 2) + // Relay Agent/Server Messages + const MessageType message_type = (MessageType)*stream.pointer(); + bool is_relay_msg = (message_type == RELAY_FORWARD || message_type == RELAY_REPLY); + uint32_t required_size = is_relay_msg ? 2 : 4; + stream.read(&header_data, required_size); + if (is_relay_message()) { + if (!stream.can_read(ipaddress_type::address_size * 2)) { throw malformed_packet(); - - uint16_t opt; - std::memcpy(&opt, buffer, sizeof(uint16_t)); - opt = Endian::be_to_host(opt); - uint16_t data_size; - std::memcpy(&data_size, buffer + sizeof(uint16_t), sizeof(uint16_t)); - data_size = Endian::be_to_host(data_size); - if(total_sz - sizeof(uint16_t) * 2 < data_size) + } + // Read both addresses + link_addr = stream.pointer(); + peer_addr = stream.pointer() + ipaddress_type::address_size; + stream.skip(ipaddress_type::address_size * 2); + } + while (stream) { + uint16_t opt = Endian::be_to_host(stream.read()); + uint16_t data_size = Endian::be_to_host(stream.read()); + if(!stream.can_read(data_size)) { throw malformed_packet(); - buffer += sizeof(uint16_t) * 2; + } add_option( - option(opt, buffer, buffer + data_size) + option(opt, stream.pointer(), stream.pointer() + data_size) ); - buffer += data_size; - total_sz -= sizeof(uint16_t) * 2 + data_size; + stream.skip(data_size); } } diff --git a/src/dns.cpp b/src/dns.cpp index 849ed84..780f0d5 100644 --- a/src/dns.cpp +++ b/src/dns.cpp @@ -39,30 +39,28 @@ #include "exceptions.h" #include "rawpdu.h" #include "endianness.h" +#include "memory_helpers.h" using std::string; using std::list; +using Tins::Memory::InputMemoryStream; + namespace Tins { DNS::DNS() -: answers_idx(), authority_idx(), additional_idx() +: dns(), answers_idx(), authority_idx(), additional_idx() { - std::memset(&dns, 0, sizeof(dns)); } DNS::DNS(const uint8_t *buffer, uint32_t total_sz) : answers_idx(), authority_idx(), additional_idx() { - if(total_sz < sizeof(dnshdr)) - throw malformed_packet(); - std::memcpy(&dns, buffer, sizeof(dnshdr)); - records_data.assign( - buffer + sizeof(dnshdr), - buffer + total_sz - ); + InputMemoryStream stream(buffer, total_sz); + stream.read(dns); + records_data.assign(stream.pointer(), stream.pointer() + stream.size()); // Avoid doing this if there's no data. Otherwise VS's asserts fail. - if(!records_data.empty()) { + if (!records_data.empty()) { buffer = &records_data[0]; const uint8_t *end = &records_data[0] + records_data.size(), *prev_start = buffer; uint16_t nquestions = questions_count(); diff --git a/src/dot1q.cpp b/src/dot1q.cpp index 528030d..6fc310e 100644 --- a/src/dot1q.cpp +++ b/src/dot1q.cpp @@ -33,6 +33,9 @@ #include "dot1q.h" #include "internals.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { @@ -45,18 +48,15 @@ Dot1Q::Dot1Q(small_uint<12> tag_id, bool append_pad) Dot1Q::Dot1Q(const uint8_t *buffer, uint32_t total_sz) : _append_padding() { - if(total_sz < sizeof(_header)) - throw malformed_packet(); - std::memcpy(&_header, buffer, sizeof(_header)); - buffer += sizeof(_header); - total_sz -= sizeof(_header); + InputMemoryStream stream(buffer, total_sz); + stream.read(_header); - if(total_sz) { + if (stream) { inner_pdu( Internals::pdu_from_flag( (Constants::Ethernet::e)payload_type(), - buffer, - total_sz + stream.pointer(), + stream.size() ) ); } diff --git a/src/dot3.cpp b/src/dot3.cpp index 07a3ca4..a4f8c82 100644 --- a/src/dot3.cpp +++ b/src/dot3.cpp @@ -47,8 +47,12 @@ #include "packet_sender.h" #include "llc.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { + const Dot3::address_type Dot3::BROADCAST("ff:ff:ff:ff:ff:ff"); Dot3::Dot3(const address_type &dst_hw_addr, const address_type &src_hw_addr) @@ -62,13 +66,11 @@ Dot3::Dot3(const address_type &dst_hw_addr, const address_type &src_hw_addr) Dot3::Dot3(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(ethhdr)) - throw malformed_packet(); - memcpy(&_eth, buffer, sizeof(ethhdr)); - buffer += sizeof(ethhdr); - total_sz -= sizeof(ethhdr); - if(total_sz) - inner_pdu(new Tins::LLC(buffer, total_sz)); + InputMemoryStream stream(buffer, total_sz); + stream.read(_eth); + if (stream) { + inner_pdu(new Tins::LLC(stream.pointer(), stream.size())); + } } void Dot3::dst_addr(const address_type &new_dst_mac) { diff --git a/src/eapol.cpp b/src/eapol.cpp index 340989a..e4c00c8 100644 --- a/src/eapol.cpp +++ b/src/eapol.cpp @@ -36,11 +36,15 @@ #include "eapol.h" #include "rsn_information.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { + EAPOL::EAPOL(uint8_t packet_type, EAPOLTYPE type) +: _header() { - std::memset(&_header, 0, sizeof(_header)); _header.version = 1; _header.packet_type = packet_type; _header.type = (uint8_t)type; @@ -48,14 +52,14 @@ EAPOL::EAPOL(uint8_t packet_type, EAPOLTYPE type) EAPOL::EAPOL(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(_header)) - throw malformed_packet(); - std::memcpy(&_header, buffer, sizeof(_header)); + InputMemoryStream stream(buffer, total_sz); + stream.read(_header); } EAPOL *EAPOL::from_bytes(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(eapolhdr)) + if (total_sz < sizeof(eapolhdr)) { throw malformed_packet(); + } const eapolhdr *ptr = (const eapolhdr*)buffer; uint32_t data_len = Endian::be_to_host(ptr->length); // at least 4 for fields always present diff --git a/src/ethernetII.cpp b/src/ethernetII.cpp index 5f824f4..152016c 100644 --- a/src/ethernetII.cpp +++ b/src/ethernetII.cpp @@ -53,33 +53,33 @@ #include "constants.h" #include "internals.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { + const EthernetII::address_type EthernetII::BROADCAST("ff:ff:ff:ff:ff:ff"); EthernetII::EthernetII(const address_type &dst_hw_addr, const address_type &src_hw_addr) +: _eth() { - memset(&_eth, 0, sizeof(ethhdr)); dst_addr(dst_hw_addr); src_addr(src_hw_addr); - _eth.payload_type = 0; - } EthernetII::EthernetII(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(ethhdr)) - throw malformed_packet(); - memcpy(&_eth, buffer, sizeof(ethhdr)); - buffer += sizeof(ethhdr); - total_sz -= sizeof(ethhdr); - if(total_sz) { + InputMemoryStream stream(buffer, total_sz); + stream.read(_eth); + // If there's any size left + if (stream) { inner_pdu( Internals::pdu_from_flag( (Constants::Ethernet::e)payload_type(), - buffer, - total_sz + stream.pointer(), + stream.size() ) ); } diff --git a/src/icmp.cpp b/src/icmp.cpp index aa223ed..2842395 100644 --- a/src/icmp.cpp +++ b/src/icmp.cpp @@ -39,6 +39,9 @@ #include "utils.h" #include "exceptions.h" #include "icmp.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { @@ -51,36 +54,20 @@ ICMP::ICMP(Flags flag) ICMP::ICMP(const uint8_t *buffer, uint32_t total_sz) { - if(total_sz < sizeof(icmphdr)) - throw malformed_packet(); - std::memcpy(&_icmp, buffer, sizeof(icmphdr)); - buffer += sizeof(icmphdr); - total_sz -= sizeof(icmphdr); - uint32_t uint32_t_buffer = 0; + InputMemoryStream stream(buffer, total_sz); + stream.read(_icmp); if(type() == TIMESTAMP_REQUEST || type() == TIMESTAMP_REPLY) { - if(total_sz < sizeof(uint32_t) * 3) - throw malformed_packet(); - memcpy(&uint32_t_buffer, buffer, sizeof(uint32_t)); - original_timestamp(uint32_t_buffer); - memcpy(&uint32_t_buffer, buffer + sizeof(uint32_t), sizeof(uint32_t)); - receive_timestamp(uint32_t_buffer); - memcpy(&uint32_t_buffer, buffer + 2 * sizeof(uint32_t), sizeof(uint32_t)); - transmit_timestamp(uint32_t_buffer); - total_sz -= sizeof(uint32_t) * 3; - buffer += sizeof(uint32_t) * 3; + original_timestamp(stream.read()); + receive_timestamp(stream.read()); + transmit_timestamp(stream.read()); } else if(type() == ADDRESS_MASK_REQUEST || type() == ADDRESS_MASK_REPLY) { - if(total_sz < sizeof(uint32_t)) - throw malformed_packet(); - memcpy(&uint32_t_buffer, buffer, sizeof(uint32_t)); - address_mask(address_type(uint32_t_buffer)); - total_sz -= sizeof(uint32_t); - buffer += sizeof(uint32_t); + address_mask(address_type(stream.read())); } // Attempt to parse ICMP extensions - try_parse_extensions(buffer, total_sz); - if (total_sz) { - inner_pdu(new RawPDU(buffer, total_sz)); + try_parse_extensions(stream); + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); } } @@ -288,10 +275,10 @@ uint32_t ICMP::get_adjusted_inner_pdu_size() const { return Internals::get_padded_icmp_inner_pdu_size(inner_pdu(), sizeof(uint32_t)); } -void ICMP::try_parse_extensions(const uint8_t* buffer, uint32_t& total_sz) { +void ICMP::try_parse_extensions(InputMemoryStream& stream) { // Check if this is one of the types defined in RFC 4884 if (are_extensions_allowed()) { - Internals::try_parse_icmp_extensions(buffer, total_sz, length() * sizeof(uint32_t), + Internals::try_parse_icmp_extensions(stream, length() * sizeof(uint32_t), extensions_); } } diff --git a/src/icmp_extension.cpp b/src/icmp_extension.cpp index 0e508fe..38a7cdd 100644 --- a/src/icmp_extension.cpp +++ b/src/icmp_extension.cpp @@ -3,9 +3,12 @@ #include "icmp_extension.h" #include "exceptions.h" #include "utils.h" +#include "memory_helpers.h" using std::runtime_error; +using Tins::Memory::InputMemoryStream; + namespace Tins { const uint32_t ICMPExtension::BASE_HEADER_SIZE = sizeof(uint16_t) + sizeof(uint8_t) * 2; @@ -24,22 +27,17 @@ ICMPExtension::ICMPExtension(uint8_t ext_class, uint8_t ext_type) ICMPExtension::ICMPExtension(const uint8_t* buffer, uint32_t total_sz) { - // Check for the base header (u16 length + u8 clss + u8 type) - if (total_sz < BASE_HEADER_SIZE) { - throw malformed_packet(); - } + InputMemoryStream stream(buffer, total_sz); - uint16_t length = Endian::be_to_host(*(const uint16_t*)buffer); - buffer += sizeof(uint16_t); - extension_class_ = *buffer++; - extension_type_ = *buffer++; - total_sz -= BASE_HEADER_SIZE; + uint16_t length = Endian::be_to_host(stream.read()); + extension_class_ = stream.read(); + extension_type_ = stream.read(); // Length is BASE_HEADER_SIZE + payload size, make sure it's valid - if (length < BASE_HEADER_SIZE || length - BASE_HEADER_SIZE > total_sz) { + if (length < BASE_HEADER_SIZE || length - BASE_HEADER_SIZE > stream.size()) { throw malformed_packet(); } length -= BASE_HEADER_SIZE; - payload_.assign(buffer, buffer + length); + payload_.assign(stream.pointer(), stream.pointer() + length); } void ICMPExtension::extension_class(uint8_t value) { @@ -88,19 +86,14 @@ ICMPExtensionsStructure::ICMPExtensionsStructure() } ICMPExtensionsStructure::ICMPExtensionsStructure(const uint8_t* buffer, uint32_t total_sz) { - if (total_sz < BASE_HEADER_SIZE) { - throw malformed_packet(); - } + InputMemoryStream stream(buffer, total_sz); - version_and_reserved_ = *(const uint16_t*)buffer; - buffer += sizeof(uint16_t); - checksum_ = *(const uint16_t*)buffer; - buffer += sizeof(uint16_t); - total_sz -= BASE_HEADER_SIZE; - while (total_sz > 0) { - extensions_.push_back(ICMPExtension(buffer, total_sz)); - uint16_t size = Endian::be_to_host(*(const uint16_t*)buffer); - total_sz -= size; + version_and_reserved_ = stream.read(); + checksum_ = stream.read(); + while (stream) { + extensions_.push_back(ICMPExtension(stream.pointer(), stream.size())); + uint16_t size = Endian::be_to_host(stream.read()); + stream.skip(size - sizeof(uint16_t)); } } diff --git a/src/icmpv6.cpp b/src/icmpv6.cpp index 75eecd9..22a50bc 100644 --- a/src/icmpv6.cpp +++ b/src/icmpv6.cpp @@ -37,6 +37,9 @@ #include "utils.h" #include "constants.h" #include "exceptions.h" +#include "memory_helpers.h" + +using Tins::Memory::InputMemoryStream; namespace Tins { @@ -50,63 +53,49 @@ ICMPv6::ICMPv6(Types tp) ICMPv6::ICMPv6(const uint8_t *buffer, uint32_t total_sz) : _options_size(), reach_time(0), retrans_timer(0) { - if (total_sz < sizeof(_header)) { - throw malformed_packet(); - } - std::memcpy(&_header, buffer, sizeof(_header)); - buffer += sizeof(_header); - total_sz -= sizeof(_header); + InputMemoryStream stream(buffer, total_sz); + stream.read(_header); if (has_target_addr()) { - if(total_sz < ipaddress_type::address_size) { - throw malformed_packet(); - } - target_addr(buffer); - buffer += ipaddress_type::address_size; - total_sz -= ipaddress_type::address_size; + _target_address = stream.read(); } if (has_dest_addr()) { - if(total_sz < ipaddress_type::address_size) { - throw malformed_packet(); - } - dest_addr(buffer); - buffer += ipaddress_type::address_size; - total_sz -= ipaddress_type::address_size; + _dest_address = stream.read(); } if (type() == ROUTER_ADVERT) { - if(total_sz < sizeof(uint32_t) * 2) { - throw malformed_packet(); - } - memcpy(&reach_time, buffer, sizeof(uint32_t)); - memcpy(&retrans_timer, buffer + sizeof(uint32_t), sizeof(uint32_t)); - - buffer += sizeof(uint32_t) * 2; - total_sz -= sizeof(uint32_t) * 2; + reach_time = stream.read(); + retrans_timer = stream.read(); } // Retrieve options if (has_options()) { - parse_options(buffer, total_sz); + parse_options(stream); } // Attempt to parse ICMP extensions - try_parse_extensions(buffer, total_sz); - if (total_sz) { - inner_pdu(new RawPDU(buffer, total_sz)); + try_parse_extensions(stream); + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); } } -void ICMPv6::parse_options(const uint8_t *&buffer, uint32_t &total_sz) { - while(total_sz > 0) { - if(total_sz < 8 || (static_cast(buffer[1]) * 8) > total_sz || buffer[1] < 1) +void ICMPv6::parse_options(InputMemoryStream& stream) { + while (stream) { + const uint8_t opt_type = stream.read(); + const uint32_t opt_size = static_cast(stream.read()) * 8; + if (opt_size < sizeof(uint8_t) << 1) { throw malformed_packet(); + } // size(option) = option_size - identifier_size - length_identifier_size + const uint32_t payload_size = opt_size - (sizeof(uint8_t) << 1); + if (!stream.can_read(payload_size)) { + throw malformed_packet(); + } add_option( option( - buffer[0], - static_cast(buffer[1]) * 8 - sizeof(uint8_t) * 2, - buffer + 2 + opt_type, + payload_size, + stream.pointer() ) ); - total_sz -= buffer[1] * 8; - buffer += buffer[1] * 8; + stream.skip(payload_size); } } @@ -660,10 +649,10 @@ uint32_t ICMPv6::get_adjusted_inner_pdu_size() const { return Internals::get_padded_icmp_inner_pdu_size(inner_pdu(), sizeof(uint64_t)); } -void ICMPv6::try_parse_extensions(const uint8_t* buffer, uint32_t& total_sz) { +void ICMPv6::try_parse_extensions(InputMemoryStream& stream) { // Check if this is one of the types defined in RFC 4884 if (are_extensions_allowed()) { - Internals::try_parse_icmp_extensions(buffer, total_sz, length() * sizeof(uint64_t), + Internals::try_parse_icmp_extensions(stream, length() * sizeof(uint64_t), extensions_); } } diff --git a/src/internals.cpp b/src/internals.cpp index 98364a7..033e0d1 100644 --- a/src/internals.cpp +++ b/src/internals.cpp @@ -52,11 +52,15 @@ #include "ip_address.h" #include "ipv6_address.h" #include "pdu_allocator.h" +#include "memory_helpers.h" using std::string; +using Tins::Memory::InputMemoryStream; + namespace Tins { namespace Internals { + bool from_hex(const string &str, uint32_t &result) { unsigned i(0); result = 0; @@ -275,9 +279,9 @@ uint32_t get_padded_icmp_inner_pdu_size(const PDU* inner_pdu, uint32_t pad_align } } -void try_parse_icmp_extensions(const uint8_t* buffer, uint32_t& total_sz, - uint32_t payload_length, ICMPExtensionsStructure& extensions) { - if (total_sz == 0) { +void try_parse_icmp_extensions(InputMemoryStream& stream, uint32_t payload_length, + ICMPExtensionsStructure& extensions) { + if (!stream) { return; } // Check if this is one of the types defined in RFC 4884 @@ -286,15 +290,15 @@ void try_parse_icmp_extensions(const uint8_t* buffer, uint32_t& total_sz, // the minimum encapsulated packet size const uint8_t* extensions_ptr; uint32_t extensions_size; - if (payload_length < total_sz && payload_length >= minimum_payload) { - extensions_ptr = buffer + payload_length; - extensions_size = total_sz - payload_length; + if (stream.can_read(payload_length) && payload_length >= minimum_payload) { + extensions_ptr = stream.pointer() + payload_length; + extensions_size = stream.size() - payload_length; } - else if (total_sz > minimum_payload) { + else if (stream.can_read(minimum_payload)) { // This packet might be non-rfc compliant. In that case the length // field can contain garbage. - extensions_ptr = buffer + minimum_payload; - extensions_size = total_sz - minimum_payload; + extensions_ptr = stream.pointer() + minimum_payload; + extensions_size = stream.size() - minimum_payload; } else { // No more special cases, this doesn't have extensions @@ -302,7 +306,7 @@ void try_parse_icmp_extensions(const uint8_t* buffer, uint32_t& total_sz, } if (ICMPExtensionsStructure::validate_extensions(extensions_ptr, extensions_size)) { extensions = ICMPExtensionsStructure(extensions_ptr, extensions_size); - total_sz -= extensions_size; + stream.size(stream.size() - extensions_size); } } diff --git a/src/ip.cpp b/src/ip.cpp index 2a8d4c9..944b8d9 100644 --- a/src/ip.cpp +++ b/src/ip.cpp @@ -48,9 +48,12 @@ #include "network_interface.h" #include "exceptions.h" #include "pdu_allocator.h" +#include "memory_helpers.h" using std::list; +using Tins::Memory::InputMemoryStream; + namespace Tins { const uint8_t IP::DEFAULT_TTL = 128; @@ -62,46 +65,51 @@ IP::IP(address_type ip_dst, address_type ip_src) this->src_addr(ip_src); } -IP::IP(const uint8_t *buffer, uint32_t total_sz) +IP::IP(const uint8_t *buffer, uint32_t total_sz) +: _options_size(0) { - if(total_sz < sizeof(iphdr)) - throw malformed_packet(); - std::memcpy(&_ip, buffer, sizeof(iphdr)); + InputMemoryStream stream(buffer, total_sz); + stream.read(_ip); - /* Options... */ - /* Establish beginning and ending of the options */ - const uint8_t* ptr_buffer = buffer + sizeof(iphdr); - if(total_sz < head_len() * sizeof(uint32_t)) + // Make sure we have enough size for options and not less than we should + if (head_len() * sizeof(uint32_t) > total_sz || + head_len() * sizeof(uint32_t) < sizeof(iphdr)) { throw malformed_packet(); - if(head_len() * sizeof(uint32_t) < sizeof(iphdr)) - throw malformed_packet(); - buffer += head_len() * sizeof(uint32_t); + } + const uint8_t* options_end = buffer + head_len() * sizeof(uint32_t); - _options_size = 0; - //_padded_options_size = head_len() * sizeof(uint32_t) - sizeof(iphdr); - /* While the end of the options is not reached read an option */ - while (ptr_buffer < buffer && (*ptr_buffer != 0)) { - //ip_option opt_to_add; - option_identifier opt_type; - memcpy(&opt_type, ptr_buffer, sizeof(uint8_t)); - ptr_buffer++; - if(opt_type.number > NOOP) { - /* Multibyte options with length as second byte */ - if(ptr_buffer == buffer || *ptr_buffer == 0) + // While the end of the options is not reached read an option + while (stream.pointer() < options_end) { + option_identifier opt_type = (option_identifier)stream.read(); + if (opt_type.number > NOOP) { + // Multibyte options with length as second byte + const uint32_t option_size = stream.read(); + if (TINS_UNLIKELY(option_size < (sizeof(uint8_t) << 1))) { throw malformed_packet(); - - const uint8_t data_size = *ptr_buffer - 2; - if(data_size > 0) { - ptr_buffer++; - if(buffer - ptr_buffer < data_size) - throw malformed_packet(); - _ip_options.push_back(option(opt_type, ptr_buffer, ptr_buffer + data_size)); } - else + // The data size is the option size - the identifier and size fields + const uint32_t data_size = option_size - (sizeof(uint8_t) << 1); + if (data_size > 0) { + if (stream.pointer() + data_size > options_end) { + throw malformed_packet(); + } + _ip_options.push_back( + option(opt_type, stream.pointer(), stream.pointer() + data_size) + ); + stream.skip(data_size); + } + else { _ip_options.push_back(option(opt_type)); - - ptr_buffer += _ip_options.back().data_size() + 1; - _options_size += static_cast(_ip_options.back().data_size() + 2); + } + _options_size += option_size; + } + else if (opt_type == END) { + // If the end option found, we're done + if (TINS_UNLIKELY(stream.pointer() != options_end)) { + // Make sure we found the END option at the end of the options list + throw malformed_packet(); + } + break; } else { _ip_options.push_back(option(opt_type)); @@ -110,39 +118,43 @@ IP::IP(const uint8_t *buffer, uint32_t total_sz) } uint8_t padding = _options_size % 4; _padded_options_size = padding ? (_options_size - padding + 4) : _options_size; - // Don't avoid consuming more than we should if tot_len is 0, - // since this is the case when using TCP segmentation offload - if (tot_len() != 0) - total_sz = std::min(total_sz, (uint32_t)tot_len()); - if (total_sz < head_len() * sizeof(uint32_t)) - throw malformed_packet(); - total_sz -= head_len() * sizeof(uint32_t); - if (total_sz) { + if (stream) { + // Don't avoid consuming more than we should if tot_len is 0, + // since this is the case when using TCP segmentation offload + if (tot_len() != 0) { + const uint32_t advertised_length = (uint32_t)tot_len() - head_len() * sizeof(uint32_t); + total_sz = std::min(stream.size(), advertised_length); + } + else { + total_sz = stream.size(); + } + // Don't try to decode it if it's fragmented - if(!is_fragmented()) { + if (!is_fragmented()) { inner_pdu( Internals::pdu_from_flag( static_cast(_ip.protocol), - buffer, + stream.pointer(), total_sz, false ) ); - if(!inner_pdu()) { + if (!inner_pdu()) { inner_pdu( Internals::allocate( _ip.protocol, - buffer, + stream.pointer(), total_sz ) ); - if(!inner_pdu()) - inner_pdu(new RawPDU(buffer, total_sz)); + if (!inner_pdu()) { + inner_pdu(new RawPDU(stream.pointer(), total_sz)); + } } } else { // It's fragmented, just use RawPDU - inner_pdu(new RawPDU(buffer, total_sz)); + inner_pdu(new RawPDU(stream.pointer(), total_sz)); } } } diff --git a/src/tcp.cpp b/src/tcp.cpp index 74094c8..d004801 100644 --- a/src/tcp.cpp +++ b/src/tcp.cpp @@ -38,17 +38,18 @@ #include "utils.h" #include "exceptions.h" #include "internals.h" +#include "memory_helpers.h" using std::find_if; +using Tins::Memory::InputMemoryStream; namespace Tins { const uint16_t TCP::DEFAULT_WINDOW = 32678; TCP::TCP(uint16_t dport, uint16_t sport) -: _options_size(0), _total_options_size(0) +: _tcp(), _options_size(0), _total_options_size(0) { - std::memset(&_tcp, 0, sizeof(tcphdr)); this->dport(dport); this->sport(sport); data_offset(sizeof(tcphdr) / sizeof(uint32_t)); @@ -56,45 +57,54 @@ TCP::TCP(uint16_t dport, uint16_t sport) } TCP::TCP(const uint8_t *buffer, uint32_t total_sz) +: _options_size(0), _total_options_size(0) { - if(total_sz < sizeof(tcphdr)) - throw malformed_packet(); - std::memcpy(&_tcp, buffer, sizeof(tcphdr)); - if(data_offset() * sizeof(uint32_t) > total_sz || data_offset() * sizeof(uint32_t) < sizeof(tcphdr)) + InputMemoryStream stream(buffer, total_sz); + stream.read(_tcp); + // Check that we have at least the amount of bytes we need and not less + if (TINS_UNLIKELY(data_offset() * sizeof(uint32_t) > total_sz || + data_offset() * sizeof(uint32_t) < sizeof(tcphdr))) { throw malformed_packet(); + } const uint8_t *header_end = buffer + (data_offset() * sizeof(uint32_t)); - total_sz = static_cast(total_sz - (header_end - buffer)); - buffer += sizeof(tcphdr); - - _total_options_size = 0; - _options_size = 0; - while(buffer < header_end) { - if(*buffer <= NOP) { + while (stream.pointer() < header_end) { + const OptionTypes option_type = (OptionTypes)stream.read(); + if (option_type <= NOP) { #if TINS_IS_CXX11 - add_option((OptionTypes)*buffer, 0); + add_option(option_type, 0); #else - add_option(option((OptionTypes)*buffer, 0)); + add_option(option(option_type, 0)); #endif // TINS_IS_CXX11 - ++buffer; } else { - if(buffer + 1 == header_end) + // Extract the length + uint32_t len = stream.read(); + const uint8_t *data_start = stream.pointer(); + + // We need to subtract the option type and length from the size + if (TINS_UNLIKELY(len < sizeof(uint8_t) << 1)) { throw malformed_packet(); - const uint8_t len = buffer[1] - (sizeof(uint8_t) << 1); - const uint8_t *data_start = buffer + 2; - if(data_start + len > header_end) + } + len -= (sizeof(uint8_t) << 1); + // Make sure we have enough bytes for the advertised option payload length + if (TINS_UNLIKELY(data_start + len > header_end)) { throw malformed_packet(); + } + // If we're using C++11, use the variadic template overload #if TINS_IS_CXX11 - add_option((OptionTypes)*buffer, data_start, data_start + len); + add_option(option_type, data_start, data_start + len); #else - add_option(option((OptionTypes)*buffer, data_start, data_start + len)); + add_option(option(option_type, data_start, data_start + len)); #endif // TINS_IS_CXX11 - buffer = data_start + len; + // Skip the option's payload + stream.skip(len); } } - if(total_sz) - inner_pdu(new RawPDU(buffer, total_sz)); + // If we still have any bytes left + if (stream) { + inner_pdu(new RawPDU(stream.pointer(), stream.size())); + } } void TCP::dport(uint16_t new_dport) {