diff --git a/src/icmp.cpp b/src/icmp.cpp index 3286a76..ef12b4b 100644 --- a/src/icmp.cpp +++ b/src/icmp.cpp @@ -262,12 +262,12 @@ void ICMP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *) } // Calculate checksum - uint32_t checksum = Utils::do_checksum(buffer, buffer + total_sz); + uint32_t checksum = Utils::sum_range(buffer, buffer + total_sz); while (checksum >> 16) { checksum = (checksum & 0xffff) + (checksum >> 16); } // Write back only the 2 checksum bytes - _icmp.check = Endian::host_to_be(~checksum); + _icmp.check = ~checksum; memcpy(buffer + 2, &_icmp.check, sizeof(uint16_t)); } diff --git a/src/icmpv6.cpp b/src/icmpv6.cpp index 583c578..ad679d5 100644 --- a/src/icmpv6.cpp +++ b/src/icmpv6.cpp @@ -274,11 +274,11 @@ void ICMPv6::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU * ipv6->dst_addr(), size(), Constants::IP::PROTO_ICMPV6 - ) + Utils::do_checksum(buffer, buffer + total_sz); + ) + Utils::sum_range(buffer, buffer + total_sz); while (checksum >> 16) { checksum = (checksum & 0xffff) + (checksum >> 16); } - this->checksum(~checksum); + this->checksum(Endian::host_to_be(~checksum)); memcpy(buffer + 2, &_header.cksum, sizeof(uint16_t)); } } diff --git a/src/tcp.cpp b/src/tcp.cpp index 508abef..bbfaecc 100644 --- a/src/tcp.cpp +++ b/src/tcp.cpp @@ -306,34 +306,32 @@ void TCP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par stream.fill(padding, 1); } - const Tins::IP *ip_packet = tins_cast(parent); - if(ip_packet) { - uint32_t check = Utils::pseudoheader_checksum( + uint32_t check = 0; + if (const Tins::IP *ip_packet = tins_cast(parent)) { + check = Utils::pseudoheader_checksum( ip_packet->src_addr(), ip_packet->dst_addr(), size(), - Constants::IP::PROTO_TCP) + Utils::do_checksum(buffer, buffer + total_sz); - while (check >> 16) { - check = (check & 0xffff) + (check >> 16); - } - checksum(~check); - ((tcphdr*)buffer)->check = _tcp.check; + Constants::IP::PROTO_TCP + ) + Utils::sum_range(buffer, buffer + total_sz); + } + else if (const Tins::IPv6 *ipv6_packet = tins_cast(parent)) { + check = Utils::pseudoheader_checksum( + ipv6_packet->src_addr(), + ipv6_packet->dst_addr(), + size(), + Constants::IP::PROTO_TCP + ) + Utils::sum_range(buffer, buffer + total_sz); } else { - const Tins::IPv6 *ipv6_packet = tins_cast(parent); - if(ipv6_packet) { - uint32_t check = Utils::pseudoheader_checksum( - ipv6_packet->src_addr(), - ipv6_packet->dst_addr(), - size(), - Constants::IP::PROTO_TCP) + Utils::do_checksum(buffer, buffer + total_sz); - while (check >> 16) { - check = (check & 0xffff) + (check >> 16); - } - checksum(~check); - ((tcphdr*)buffer)->check = _tcp.check; - } + return; } + // Convert this 32-bit value into a 16-bit value + while (check >> 16) { + check = (check & 0xffff) + (check >> 16); + } + checksum(Endian::host_to_be(~check)); + ((tcphdr*)buffer)->check = _tcp.check; } const TCP::option *TCP::search_option(OptionTypes type) const { diff --git a/src/udp.cpp b/src/udp.cpp index 02449d5..e84557c 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -74,6 +74,43 @@ uint32_t UDP::header_size() const { return sizeof(udphdr); } +uint32_t sum_range(const uint8_t *start, const uint8_t *end) { + uint32_t checksum(0); + const uint8_t *last = end; + uint16_t buffer = 0; + uint16_t padding = 0; + const uint8_t *ptr = start; + + if(((end - start) & 1) == 1) { + last = end - 1; + padding = Endian::host_to_le(*(end - 1)); + } + + while(ptr < last) { + memcpy(&buffer, ptr, sizeof(uint16_t)); + checksum += buffer; + ptr += sizeof(uint16_t); + } + + checksum += padding; + return checksum; +} + +uint32_t pseudoheader_checksum(IPv4Address source_ip, IPv4Address dest_ip, uint32_t len, uint32_t flag) { + uint32_t checksum(0); + uint8_t buffer[sizeof(uint32_t) * 3]; + OutputMemoryStream stream(buffer, sizeof(buffer)); + stream.write(source_ip); + stream.write(dest_ip); + stream.write(Endian::host_to_be(flag)); + stream.write(Endian::host_to_be(len)); + uint16_t *ptr = (uint16_t*)buffer, *end = (uint16_t*)(buffer + sizeof(buffer)); + while (ptr < end) { + checksum += *ptr++; + } + return checksum; +} + void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *parent) { OutputMemoryStream stream(buffer, total_sz); // Set checksum to 0, we'll calculate it at the end @@ -85,36 +122,31 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par length(static_cast(sizeof(udphdr))); } stream.write(_udp); - const Tins::IP *ip_packet = tins_cast(parent); - if(ip_packet) { - uint32_t checksum = Utils::pseudoheader_checksum( + uint32_t checksum = 0; + if (const Tins::IP *ip_packet = tins_cast(parent)) { + checksum = Utils::pseudoheader_checksum( ip_packet->src_addr(), ip_packet->dst_addr(), size(), Constants::IP::PROTO_UDP - ) + Utils::do_checksum(buffer, buffer + total_sz); - while (checksum >> 16) { - checksum = (checksum & 0xffff)+(checksum >> 16); - } - _udp.check = Endian::host_to_be(~checksum); - ((udphdr*)buffer)->check = _udp.check; + ) + Utils::sum_range(buffer, buffer + total_sz); + } + else if (const Tins::IPv6 *ip6_packet = tins_cast(parent)) { + checksum = Utils::pseudoheader_checksum( + ip6_packet->src_addr(), + ip6_packet->dst_addr(), + size(), + Constants::IP::PROTO_UDP + ) + Utils::sum_range(buffer, buffer + total_sz); } else { - const Tins::IPv6 *ip6_packet = tins_cast(parent); - if(ip6_packet) { - uint32_t checksum = Utils::pseudoheader_checksum( - ip6_packet->src_addr(), - ip6_packet->dst_addr(), - size(), - Constants::IP::PROTO_UDP - ) + Utils::do_checksum(buffer, buffer + total_sz); - while (checksum >> 16) { - checksum = (checksum & 0xffff)+(checksum >> 16); - } - _udp.check = Endian::host_to_be(~checksum); - ((udphdr*)buffer)->check = _udp.check; - } + return; } + while (checksum >> 16) { + checksum = (checksum & 0xffff)+(checksum >> 16); + } + _udp.check = ~checksum; + ((udphdr*)buffer)->check = _udp.check; } bool UDP::matches_response(const uint8_t *ptr, uint32_t total_sz) const { diff --git a/src/utils.cpp b/src/utils.cpp index c593e57..4fe9e57 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -56,9 +56,11 @@ #include "network_interface.h" #include "packet_sender.h" #include "cxxstd.h" +#include "memory_helpers.h" using namespace std; +using Tins::Memory::OutputMemoryStream; /** \cond */ struct InterfaceCollector { @@ -248,33 +250,34 @@ uint16_t sum_range(const uint8_t *start, const uint8_t *end) { return checksum; } -uint32_t pseudoheader_checksum(IPv4Address source_ip, IPv4Address dest_ip, uint32_t len, uint32_t flag) { +template +uint32_t generic_pseudoheader_checksum(const AddressType& source_ip, const AddressType& dest_ip, + uint16_t len, uint16_t flag) { uint32_t checksum(0); - uint32_t source_ip_int = Endian::host_to_be(source_ip), - dest_ip_int = Endian::host_to_be(dest_ip); - char buffer[sizeof(uint32_t) * 2]; + uint8_t buffer[buffer_size]; + OutputMemoryStream stream(buffer, sizeof(buffer)); + stream.write(source_ip); + stream.write(dest_ip); + stream.write(Endian::host_to_be(flag)); + stream.write(Endian::host_to_be(len)); + uint16_t *ptr = (uint16_t*)buffer, *end = (uint16_t*)(buffer + sizeof(buffer)); - std::memcpy(buffer, &source_ip_int, sizeof(source_ip_int)); - std::memcpy(buffer + sizeof(uint32_t), &dest_ip_int, sizeof(dest_ip_int)); - while(ptr < end) - checksum += (uint32_t)*ptr++; - checksum += flag + len; + while (ptr < end) { + checksum += *ptr++; + } return checksum; } -uint32_t pseudoheader_checksum(IPv6Address source_ip, IPv6Address dest_ip, uint32_t len, uint32_t flag) { - uint32_t checksum(0); - uint16_t *ptr = (uint16_t*) source_ip.begin(); - uint16_t *end = (uint16_t*) source_ip.end(); - while(ptr < end) - checksum += (uint32_t) Endian::be_to_host(*ptr++); +uint32_t pseudoheader_checksum(IPv4Address source_ip, IPv4Address dest_ip, uint16_t len, uint16_t flag) { + return generic_pseudoheader_checksum( + source_ip, dest_ip, len, flag + ); +} - ptr = (uint16_t*) dest_ip.begin(); - end = (uint16_t*) dest_ip.end(); - while(ptr < end) - checksum += (uint32_t) Endian::be_to_host(*ptr++); - checksum += flag + len; - return checksum; +uint32_t pseudoheader_checksum(IPv6Address source_ip, IPv6Address dest_ip, uint16_t len, uint16_t flag) { + return generic_pseudoheader_checksum( + source_ip, dest_ip, len, flag + ); } uint32_t crc32(const uint8_t* data, uint32_t data_size) { diff --git a/tests/src/tcp.cpp b/tests/src/tcp.cpp index 2eeb7e9..40724e5 100644 --- a/tests/src/tcp.cpp +++ b/tests/src/tcp.cpp @@ -51,6 +51,14 @@ TEST_F(TCPTest, ChecksumCheck) { uint16_t checksum = tcp1.checksum(); PDU::serialization_type buffer = pkt1.serialize(); + EXPECT_EQ( + TCP::serialization_type( + checksum_packet, + checksum_packet + sizeof(checksum_packet) + ), + buffer + ); + EthernetII pkt2(&buffer[0], (uint32_t)buffer.size()); const TCP &tcp2 = pkt2.rfind_pdu(); EXPECT_EQ(checksum, tcp2.checksum()); diff --git a/tests/src/udp.cpp b/tests/src/udp.cpp index 649141c..11dda8d 100644 --- a/tests/src/udp.cpp +++ b/tests/src/udp.cpp @@ -11,7 +11,8 @@ using namespace Tins; class UDPTest : public testing::Test { public: - static const uint8_t expected_packet[], checksum_packet[]; + static const uint8_t expected_packet[], checksum_packet[], + checksum_packet2[]; void test_equals(const UDP& udp1, const UDP& udp2); }; @@ -28,6 +29,20 @@ const uint8_t UDPTest::checksum_packet[] = { 98, 111, 111, 107, 3, 99, 111, 109, 0, 0, 1, 0, 1 }; +const uint8_t UDPTest::checksum_packet2[] = { + 0, 20, 165, 53, 119, 224, 44, 240, 238, 33, 128, 46, 8, 0, 69, 184, 0, + 200, 9, 187, 0, 0, 63, 17, 107, 202, 192, 168, 6, 224, 198, 199, 118, + 152, 217, 252, 192, 0, 0, 180, 250, 82, 128, 0, 0, 106, 86, 129, 110, + 22, 2, 46, 39, 16, 0, 0, 7, 111, 0, 0, 34, 42, 86, 129, 110, 20, 0, 14, + 255, 229, 0, 0, 8, 234, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0 +}; + void UDPTest::test_equals(const UDP& udp1, const UDP& udp2) { EXPECT_EQ(udp1.dport(), udp2.dport()); @@ -49,14 +64,34 @@ TEST_F(UDPTest, ChecksumCheck) { EthernetII pkt1(checksum_packet, sizeof(checksum_packet)); const UDP &udp1 = pkt1.rfind_pdu(); uint16_t checksum = udp1.checksum(); - PDU::serialization_type buffer = pkt1.serialize(); + EXPECT_EQ( + UDP::serialization_type( + checksum_packet, + checksum_packet + sizeof(checksum_packet) + ), + buffer + ); + EthernetII pkt2(&buffer[0], (uint32_t)buffer.size()); const UDP &udp2 = pkt2.rfind_pdu(); EXPECT_EQ(checksum, udp2.checksum()); EXPECT_EQ(udp1.checksum(), udp2.checksum()); } +TEST_F(UDPTest, ChecksumCheck2) { + EthernetII pkt(checksum_packet2, sizeof(checksum_packet2)); + PDU::serialization_type buffer = pkt.serialize(); + EXPECT_EQ( + UDP::serialization_type( + checksum_packet2, + checksum_packet2 + sizeof(checksum_packet2) + ), + buffer + ); + EXPECT_EQ(0xfa52, pkt.rfind_pdu().checksum()); +} + TEST_F(UDPTest, CopyConstructor) { UDP udp1(expected_packet, sizeof(expected_packet)); UDP udp2(udp1);