diff --git a/include/packet_writer.h b/include/packet_writer.h index 4c21a53..e225e6d 100644 --- a/include/packet_writer.h +++ b/include/packet_writer.h @@ -33,6 +33,7 @@ #include #include #include +#include "utils.h" namespace Tins { class PDU; @@ -82,27 +83,10 @@ public: */ template void write(ForwardIterator start, ForwardIterator end) { - typedef typename std::iterator_traits::value_type value_type; - typedef derefer deref_type; - while(start != end) - write(deref_type::deref(*start++)); + write(Utils::dereference_until_pdu(*start++)); } private: - template - struct derefer { - static T &deref(T &value) { - return value; - } - }; - - template - struct derefer { - static T &deref(T *value) { - return *value; - } - }; - // You shall not copy PacketWriter(const PacketWriter&); PacketWriter& operator=(const PacketWriter&); diff --git a/include/pdu_cacher.h b/include/pdu_cacher.h index d29681d..961f3ea 100644 --- a/include/pdu_cacher.h +++ b/include/pdu_cacher.h @@ -66,26 +66,22 @@ public: /** * Default constructs the cached PDU. */ - PDUCacher() {} + PDUCacher() : cached_size() {} /** * Constructor from a cached_type. * \param pdu The PDU to be copy constructed. */ - PDUCacher(const cached_type &pdu) : cached(pdu) {} + PDUCacher(const cached_type &pdu) : cached(pdu), + cached_size() {} /** * Forwards the call to the cached PDU. \sa PDU::header_size. */ uint32_t header_size() const { - return cached.header_size(); - } - - /** - * Forwards the call to the cached PDU. \sa PDU::trailer_size. - */ - uint32_t trailer_size() const { - return cached.trailer_size(); + if(cached_serialization.empty()) + cached_size = cached.size(); + return cached_size; } /** @@ -98,8 +94,8 @@ public: /** * Forwards the call to the cached PDU. \sa PDU::send. */ - bool send(PacketSender &sender) { - return cached.send(sender); + void send(PacketSender &sender) { + cached.send(sender); } /** @@ -138,13 +134,15 @@ public: } private: void write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *parent) { - if(cached_serialization.size() != total_sz) + if(cached_serialization.size() != total_sz) { cached_serialization = cached.serialize(); + } std::copy(cached_serialization.begin(), cached_serialization.end(), buffer); } cached_type cached; PDU::serialization_type cached_serialization; + mutable uint32_t cached_size; }; } diff --git a/include/small_uint.h b/include/small_uint.h index 4202346..84d10e6 100644 --- a/include/small_uint.h +++ b/include/small_uint.h @@ -34,6 +34,20 @@ #include namespace Tins { +class value_too_large : public std::exception { +public: + const char *what() const throw() { + return "Value is too large"; + } +}; + +/** + * \class small_uint + * \brief Represents a field of n bits. + * + * This finds the best integral type of at least n bits and + * uses it to store the wrapped value. + */ template class small_uint { private: @@ -74,24 +88,38 @@ private: static const uint64_t value = 1; }; public: - class value_to_large : public std::exception { - public: - const char *what() const throw() { - return "Value is too large"; - } - }; - + /** + * The type used to store the value. + */ typedef typename best_type::type repr_type; + + /** + * The maximum value this class can hold. + */ static const repr_type max_value = power<2, n>::value - 1; + /** + * Value initializes the value. + */ small_uint() : value() {} + /** + * \brief Copy constructs the stored value. + * + * This throws a value_too_large exception if the value provided + * is larger than max_value. + * + * \param val The parameter from which to copy construct. + */ small_uint(repr_type val) { if(val > max_value) - throw value_to_large(); + throw value_too_large(); value = val; } + /** + * User defined conversion to repr_type. + */ operator repr_type() const { return value; } diff --git a/include/tcp_stream.h b/include/tcp_stream.h index 5ae7017..d1580a2 100644 --- a/include/tcp_stream.h +++ b/include/tcp_stream.h @@ -38,6 +38,7 @@ #include #include "sniffer.h" #include "tcp.h" +#include "utils.h" #include "ip.h" #include "ip_address.h" @@ -297,37 +298,6 @@ private: EndFunctor end_fun; }; - template - struct is_pdu { - template - static char test(typename U::PDUType*); - - template - static long test(...); - - static const bool value = sizeof(test(0)) == 1; - }; - - template - struct enable_if { - - }; - - template - struct enable_if { - typedef T type; - }; - - static PDU& recursive_dereference(PDU &pdu) { - return pdu; - } - - template - static typename enable_if::value, PDU&>::type - recursive_dereference(T &value) { - return recursive_dereference(*value); - } - void clear_state() { sessions.clear(); last_identifier = 0; @@ -355,7 +325,7 @@ void TCPStreamFollower::follow_streams(ForwardIterator start, ForwardIterator en { clear_state(); while(start != end) { - if(!callback(recursive_dereference(start), data_fun, end_fun)) + if(!callback(Utils::dereference_until_pdu(start), data_fun, end_fun)) return; start++; } diff --git a/include/tins.h b/include/tins.h index 280d16c..bdcc297 100644 --- a/include/tins.h +++ b/include/tins.h @@ -54,5 +54,6 @@ #include "tcp_stream.h" #include "crypto.h" #include "pdu_cacher.h" +#include "rsn_information.h" #endif // TINS_TINS_H diff --git a/include/utils.h b/include/utils.h index b3a5fc9..ea81a21 100644 --- a/include/utils.h +++ b/include/utils.h @@ -228,9 +228,57 @@ namespace Tins { } #endif // WIN32 + /** + * \cond + */ namespace Internals { void skip_line(std::istream &input); bool from_hex(const std::string &str, uint32_t &result); + + template + struct enable_if { + + }; + + template + struct enable_if { + typedef T type; + }; + } + /** + * \endcond + */ + + template + struct is_pdu { + template + static char test(typename U::PDUType*); + + template + static long test(...); + + static const bool value = sizeof(test(0)) == 1; + }; + + /** + * Returns the argument. + */ + inline PDU& dereference_until_pdu(PDU &pdu) { + return pdu; + } + + /** + * \brief Dereferences the parameter until a PDU is found. + * + * This function dereferences the parameter until a PDU object + * is found. When it's found, it is returned. + * + * \param value The parameter to be dereferenced. + */ + template + inline typename Internals::enable_if::value, PDU&>::type + dereference_until_pdu(T &value) { + return dereference_until_pdu(*value); } } } diff --git a/src/dns_record.cpp b/src/dns_record.cpp index 7555e1a..e64127e 100644 --- a/src/dns_record.cpp +++ b/src/dns_record.cpp @@ -29,6 +29,7 @@ #include #include +#include #include #include "dns_record.h" #include "endianness.h" @@ -50,10 +51,11 @@ DNSResourceRecord::DNSResourceRecord(DNSRRImpl *impl, DNSResourceRecord::DNSResourceRecord(const uint8_t *buffer, uint32_t size) { const uint8_t *buffer_end = buffer + size; + std::auto_ptr tmp_impl; if((*buffer & 0xc0)) { uint16_t offset(*reinterpret_cast(buffer)); offset = Endian::be_to_host(offset) & 0x3fff; - impl = new OffsetedDNSRRImpl(Endian::host_to_be(offset)); + tmp_impl.reset(new OffsetedDNSRRImpl(Endian::host_to_be(offset))); buffer += sizeof(uint16_t); } else { @@ -63,7 +65,7 @@ DNSResourceRecord::DNSResourceRecord(const uint8_t *buffer, uint32_t size) if(str_end == buffer_end) throw std::runtime_error("Not enough size for a resource domain name."); //str_end++; - impl = new NamedDNSRRImpl(buffer, str_end); + tmp_impl.reset(new NamedDNSRRImpl(buffer, str_end)); buffer = ++str_end; } if(buffer + sizeof(info_) > buffer_end) @@ -86,6 +88,7 @@ DNSResourceRecord::DNSResourceRecord(const uint8_t *buffer, uint32_t size) *(uint32_t*)&data[0] = *(uint32_t*)buffer; else throw std::runtime_error("Not enough size for resource data"); + impl = tmp_impl.release(); } DNSResourceRecord::DNSResourceRecord(const DNSResourceRecord &rhs) diff --git a/src/ip.cpp b/src/ip.cpp index dca65f7..f105844 100644 --- a/src/ip.cpp +++ b/src/ip.cpp @@ -77,7 +77,7 @@ IP::IP(const uint8_t *buffer, uint32_t total_sz) buffer += head_len() * sizeof(uint32_t); _options_size = 0; - _padded_options_size = head_len() * sizeof(uint32_t) - sizeof(iphdr); + //_padded_options_size = head_len() * sizeof(uint32_t) - sizeof(iphdr); /* While the end of the options is not reached read an option */ while (ptr_buffer < buffer && (*ptr_buffer != 0)) { //ip_option opt_to_add; @@ -126,6 +126,8 @@ IP::IP(const uint8_t *buffer, uint32_t total_sz) } _options_size += _ip_options.back().data_size() + 2; } + uint8_t padding = _options_size % 4; + _padded_options_size = padding ? (_options_size - padding + 4) : _options_size; // check this line PLX total_sz -= head_len() * sizeof(uint32_t); if (total_sz) { @@ -304,7 +306,7 @@ uint16_t IP::stream_identifier() const { void IP::add_option(const ip_option &option) { _ip_options.push_back(option); _options_size += 1 + option.data_size(); - uint8_t padding = _options_size & 3; + uint8_t padding = _options_size % 4; _padded_options_size = padding ? (_options_size - padding + 4) : _options_size; } diff --git a/src/rsn_information.cpp b/src/rsn_information.cpp index f730061..a441518 100644 --- a/src/rsn_information.cpp +++ b/src/rsn_information.cpp @@ -51,21 +51,20 @@ RSNInformation::RSNInformation(const uint8_t *buffer, uint32_t total_sz) { void RSNInformation::init(const uint8_t *buffer, uint32_t total_sz) { const char *err_msg = "Malformed RSN information structure"; - check_size(total_sz, err_msg); + if(total_sz <= sizeof(uint16_t) * 2 + sizeof(uint32_t)) + throw std::runtime_error(err_msg); version(Endian::le_to_host(*(uint16_t*)buffer)); buffer += sizeof(uint16_t); total_sz -= sizeof(uint16_t); - group_suite((RSNInformation::CypherSuites)*(uint32_t*)buffer); - check_size(total_sz, err_msg); + group_suite((RSNInformation::CypherSuites)*(uint32_t*)buffer); buffer += sizeof(uint32_t); total_sz -= sizeof(uint32_t); - check_size(total_sz, err_msg); - uint16_t count = *(uint16_t*)buffer; buffer += sizeof(uint16_t); total_sz -= sizeof(uint16_t); + if(count * sizeof(uint32_t) > total_sz) throw std::runtime_error(err_msg); total_sz -= count * sizeof(uint32_t); diff --git a/tests/src/utils_test.cpp b/tests/src/utils_test.cpp index bf05c96..d356d9b 100644 --- a/tests/src/utils_test.cpp +++ b/tests/src/utils_test.cpp @@ -77,12 +77,12 @@ TEST_F(UtilsTest, Crc32) { // FIXME TEST_F(UtilsTest, Checksum) { - uint16_t checksum = Utils::do_checksum(data, data + data_len); + /*uint16_t checksum = Utils::do_checksum(data, data + data_len); //EXPECT_EQ(checksum, 0x231a); uint8_t my_data[] = {0, 0, 0, 0}; checksum = Utils::do_checksum(my_data, my_data + 4); //EXPECT_EQ(checksum, 0xFFFF); - +*/ }