diff --git a/src/icmpv6.cpp b/src/icmpv6.cpp index ce00043..cadce1c 100644 --- a/src/icmpv6.cpp +++ b/src/icmpv6.cpp @@ -198,6 +198,7 @@ void ICMPv6::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU * assert(total_sz >= header_size()); #endif icmp6hdr* ptr_header = (icmp6hdr*)buffer; + _header.cksum = 0; std::memcpy(buffer, &_header, sizeof(_header)); buffer += sizeof(_header); total_sz -= sizeof(_header); @@ -224,19 +225,18 @@ void ICMPv6::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU * #endif buffer = write_option(*it, buffer); } - if(!_header.cksum) { - const Tins::IPv6 *ipv6 = tins_cast(parent); - if(ipv6) { - uint32_t checksum = Utils::pseudoheader_checksum( - ipv6->src_addr(), - ipv6->dst_addr(), - size(), - Constants::IP::PROTO_ICMPV6 - ) + Utils::do_checksum((uint8_t*)ptr_header, buffer); - while (checksum >> 16) - checksum = (checksum & 0xffff) + (checksum >> 16); - ptr_header->cksum = Endian::host_to_be(~checksum); - } + const Tins::IPv6 *ipv6 = tins_cast(parent); + if(ipv6) { + uint32_t checksum = Utils::pseudoheader_checksum( + ipv6->src_addr(), + ipv6->dst_addr(), + size(), + Constants::IP::PROTO_ICMPV6 + ) + Utils::do_checksum((uint8_t*)ptr_header, buffer); + while (checksum >> 16) + checksum = (checksum & 0xffff) + (checksum >> 16); + this->checksum(~checksum); + ptr_header->cksum = _header.cksum; } } diff --git a/src/utils.cpp b/src/utils.cpp index 856b181..8b81dd0 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -184,12 +184,16 @@ uint32_t pseudoheader_checksum(IPv4Address source_ip, IPv4Address dest_ip, uint3 } uint32_t pseudoheader_checksum(IPv6Address source_ip, IPv6Address dest_ip, uint32_t len, uint32_t flag) { - uint32_t checksum = 0; - IPv6Address::const_iterator it; - for(it = source_ip.begin(); it != source_ip.end(); ++it) - checksum += *it; - for(it = dest_ip.begin(); it != dest_ip.end(); ++it) - checksum += *it; + uint32_t checksum(0); + uint16_t *ptr = (uint16_t*) source_ip.begin(); + uint16_t *end = (uint16_t*) source_ip.end(); + while(ptr < end) + checksum += (uint32_t) Endian::be_to_host(*ptr++); + + ptr = (uint16_t*) dest_ip.begin(); + end = (uint16_t*) dest_ip.end(); + while(ptr < end) + checksum += (uint32_t) Endian::be_to_host(*ptr++); checksum += flag + len; return checksum; } diff --git a/tests/src/icmpv6.cpp b/tests/src/icmpv6.cpp index d8dc7bb..369d24c 100644 --- a/tests/src/icmpv6.cpp +++ b/tests/src/icmpv6.cpp @@ -3,6 +3,7 @@ #include #include #include "icmpv6.h" +#include "ethernetII.h" #include "ip.h" #include "tcp.h" #include "utils.h" @@ -14,6 +15,7 @@ class ICMPv6Test : public testing::Test { public: static const uint8_t expected_packet[]; static const uint8_t expected_packet1[]; + static const uint8_t expected_packet2[]; void test_equals(const ICMPv6 &icmp1, const ICMPv6 &icmp2); }; @@ -31,6 +33,15 @@ const uint8_t ICMPv6Test::expected_packet1[] = { 0, 0, 0, 0, 0, 0, 0 }; +const uint8_t ICMPv6Test::expected_packet2[] = { + 0, 96, 151, 7, 105, 234, 0, 0, 134, 5, 128, 218, 134, 221, 96, 0, + 0, 0, 0, 32, 58, 255, 254, 128, 0, 0, 0, 0, 0, 0, 2, 0, 134, 255 + , 254, 5, 128, 218, 254, 128, 0, 0, 0, 0, 0, 0, 2, 96, 151, 255, + 254, 7, 105, 234, 135, 0, 0, 0, 0, 0, 0, 0, 254, 128, 0, 0, 0 + , 0, 0, 0, 2, 96, 151, 255, 254, 7, 105, 234, 1, 1, 0, 0, 134, 5, + 128, 218 +}; + TEST_F(ICMPv6Test, Constructor) { ICMPv6 icmp; EXPECT_EQ(icmp.type(), ICMPv6::ECHO_REQUEST); @@ -452,3 +463,10 @@ TEST_F(ICMPv6Test, SpoofedOptions) { EXPECT_EQ(3U, pdu.options().size()); EXPECT_EQ(pdu.serialize().size(), pdu.size()); } + +TEST_F(ICMPv6Test, ChecksumCalculation) { + EthernetII eth(expected_packet2, sizeof(expected_packet2)); + EthernetII::serialization_type serialized = eth.serialize(); + const ICMPv6& icmp = eth.rfind_pdu(); + EXPECT_EQ(0x68bd, icmp.checksum()); +} diff --git a/tests/src/ipv6.cpp b/tests/src/ipv6.cpp index d92093a..1a7f9a4 100644 --- a/tests/src/ipv6.cpp +++ b/tests/src/ipv6.cpp @@ -23,7 +23,7 @@ public: const uint8_t IPv6Test::expected_packet1[] = { 105, 168, 39, 52, 0, 40, 6, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 198, 140, - 0, 80, 104, 72, 3, 12, 0, 0, 0, 0, 160, 2, 127, 240, 0, 48, 0, 0, 2, + 0, 80, 104, 72, 3, 12, 0, 0, 0, 0, 160, 2, 127, 240, 183, 120, 0, 0, 2, 4, 63, 248, 4, 2, 8, 10, 0, 132, 163, 156, 0, 0, 0, 0, 1, 3, 3, 7 }; @@ -115,10 +115,13 @@ TEST_F(IPv6Test, ConstructorFromBuffer2) { } TEST_F(IPv6Test, Serialize) { - IPv6 ip1(expected_packet2, sizeof(expected_packet2)); + IPv6 ip1(expected_packet1, sizeof(expected_packet1)); IPv6::serialization_type buffer = ip1.serialize(); - ASSERT_EQ(buffer.size(), sizeof(expected_packet2)); - EXPECT_TRUE(std::equal(buffer.begin(), buffer.end(), expected_packet2)); + ASSERT_EQ(buffer.size(), sizeof(expected_packet1)); + EXPECT_EQ( + IPv6::serialization_type(expected_packet1, expected_packet1 + sizeof(expected_packet1)), + buffer + ); IPv6 ip2(&buffer[0], buffer.size()); test_equals(ip1, ip2); }