diff --git a/include/packet_writer.h b/include/packet_writer.h index 28258f5..b015a90 100644 --- a/include/packet_writer.h +++ b/include/packet_writer.h @@ -98,6 +98,18 @@ public: */ void write(PDU &pdu); + /** + * \brief Writes a PDU to this file. + * + * The template parameter T must at some point yield a PDU& after + * applying operator* one or more than one time. This accepts both + * raw and smartpointers. + */ + template + void write(T &pdu) { + write(Utils::dereference_until_pdu(pdu)); + } + /** * \brief Writes all the PDUs in the range [start, end) * \param start A forward iterator pointing to the first PDU diff --git a/src/ip.cpp b/src/ip.cpp index 38dee63..97ee561 100644 --- a/src/ip.cpp +++ b/src/ip.cpp @@ -378,6 +378,7 @@ void IP::prepare_for_serialize(const PDU *parent) { void IP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU* parent) { uint32_t my_sz = header_size(); assert(total_sz >= my_sz); + check(0); if(inner_pdu()) { uint32_t new_flag; switch(inner_pdu()->pdu_type()) { @@ -415,12 +416,12 @@ void IP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU* pare ptr_buffer = write_option(*it, ptr_buffer); memset(buffer + sizeof(_ip) + _options_size, 0, _padded_options_size - _options_size); - if(parent && !_ip.check) { + if(parent) { uint32_t checksum = Utils::do_checksum(buffer, buffer + sizeof(_ip) + _padded_options_size); while (checksum >> 16) checksum = (checksum & 0xffff) + (checksum >> 16); - ((iphdr*)buffer)->check = Endian::host_to_be(~checksum); - this->check(0); + check(~checksum); + ((iphdr*)buffer)->check = _ip.check; } } diff --git a/src/tcp.cpp b/src/tcp.cpp index 05a806d..f5a65b3 100644 --- a/src/tcp.cpp +++ b/src/tcp.cpp @@ -284,6 +284,7 @@ uint32_t TCP::header_size() const { void TCP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *parent) { assert(total_sz >= header_size()); uint8_t *tcp_start = buffer; + check(0); buffer += sizeof(tcphdr); _tcp.doff = (sizeof(tcphdr) + _total_options_size) / sizeof(uint32_t); for(options_type::iterator it = _options.begin(); it != _options.end(); ++it) @@ -299,34 +300,30 @@ void TCP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par memcpy(tcp_start, &_tcp, sizeof(tcphdr)); - if(!_tcp.check) { - const Tins::IP *ip_packet = dynamic_cast(parent); - if(ip_packet) { - uint32_t checksum = Utils::pseudoheader_checksum(ip_packet->src_addr(), - ip_packet->dst_addr(), - size(), Constants::IP::PROTO_TCP) + - Utils::do_checksum(tcp_start, tcp_start + total_sz); + const Tins::IP *ip_packet = dynamic_cast(parent); + if(ip_packet) { + uint32_t checksum = Utils::pseudoheader_checksum(ip_packet->src_addr(), + ip_packet->dst_addr(), + size(), Constants::IP::PROTO_TCP) + + Utils::do_checksum(tcp_start, tcp_start + total_sz); + while (checksum >> 16) + checksum = (checksum & 0xffff) + (checksum >> 16); + check(~checksum); + ((tcphdr*)tcp_start)->check = _tcp.check; + } + else { + const Tins::IPv6 *ipv6_packet = dynamic_cast(parent); + if(ipv6_packet) { + uint32_t checksum = Utils::pseudoheader_checksum(ipv6_packet->src_addr(), + ipv6_packet->dst_addr(), + size(), Constants::IP::PROTO_TCP) + + Utils::do_checksum(tcp_start, tcp_start + total_sz); while (checksum >> 16) checksum = (checksum & 0xffff) + (checksum >> 16); - - ((tcphdr*)tcp_start)->check = Endian::host_to_be(~checksum); - } - else { - const Tins::IPv6 *ipv6_packet = dynamic_cast(parent); - if(ipv6_packet) { - uint32_t checksum = Utils::pseudoheader_checksum(ipv6_packet->src_addr(), - ipv6_packet->dst_addr(), - size(), Constants::IP::PROTO_TCP) + - Utils::do_checksum(tcp_start, tcp_start + total_sz); - while (checksum >> 16) - checksum = (checksum & 0xffff) + (checksum >> 16); - - ((tcphdr*)tcp_start)->check = Endian::host_to_be(~checksum); - } + check(~checksum); + ((tcphdr*)tcp_start)->check = _tcp.check; } } - - _tcp.check = 0; } const TCP::option *TCP::search_option(OptionTypes opt) const {