diff --git a/src/udp.cpp b/src/udp.cpp index e84557c..525aa57 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -146,6 +146,8 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par checksum = (checksum & 0xffff)+(checksum >> 16); } _udp.check = ~checksum; + // If checksum is 0, it has to be set to 0xffff + _udp.check = (_udp.check == 0) ? 0xffff : _udp.check; ((udphdr*)buffer)->check = _udp.check; } diff --git a/tests/src/udp.cpp b/tests/src/udp.cpp index 11dda8d..4cf1a13 100644 --- a/tests/src/udp.cpp +++ b/tests/src/udp.cpp @@ -12,7 +12,7 @@ using namespace Tins; class UDPTest : public testing::Test { public: static const uint8_t expected_packet[], checksum_packet[], - checksum_packet2[]; + checksum_packet2[], checksum_packet3[]; void test_equals(const UDP& udp1, const UDP& udp2); }; @@ -43,6 +43,19 @@ const uint8_t UDPTest::checksum_packet2[] = { 0, 0, 0, 0 }; +const uint8_t UDPTest::checksum_packet3[] = { + 0, 20, 165, 53, 119, 224, 44, 240, 238, 33, 128, 46, 8, 0, 69, 184, 0, + 200, 127, 204, 0, 0, 28, 17, 24, 185, 192, 168, 6, 224, 198, 199, 118, + 152, 213, 50, 192, 0, 0, 180, 255, 255, 128, 0, 0, 29, 86, 130, 177, + 157, 1, 46, 0, 0, 0, 0, 7, 111, 0, 0, 52, 134, 86, 130, 177, 132, 0, + 5, 150, 253, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +}; void UDPTest::test_equals(const UDP& udp1, const UDP& udp2) { EXPECT_EQ(udp1.dport(), udp2.dport()); @@ -92,6 +105,20 @@ TEST_F(UDPTest, ChecksumCheck2) { EXPECT_EQ(0xfa52, pkt.rfind_pdu().checksum()); } +// This checksum's 0. We should set it to 0xffff instead +TEST_F(UDPTest, ChecksumCheck3) { + EthernetII pkt(checksum_packet3, sizeof(checksum_packet3)); + PDU::serialization_type buffer = pkt.serialize(); + EXPECT_EQ( + UDP::serialization_type( + checksum_packet3, + checksum_packet3 + sizeof(checksum_packet3) + ), + buffer + ); + EXPECT_EQ(0xffff, pkt.rfind_pdu().checksum()); +} + TEST_F(UDPTest, CopyConstructor) { UDP udp1(expected_packet, sizeof(expected_packet)); UDP udp2(udp1);