diff --git a/include/udp.h b/include/udp.h index c0cbfa7..6afd89b 100644 --- a/include/udp.h +++ b/include/udp.h @@ -91,13 +91,15 @@ namespace Tins { */ void dport(uint16_t new_dport); - /** \brief Set the source port. + /** + * \brief Set the source port. * * \param new_sport The new source port. */ void sport(uint16_t new_sport); - /** \brief Getter for the length field. + /** + * \brief Getter for the length field. * \param new_len The new length field. * \return The length field. */ @@ -115,7 +117,8 @@ namespace Tins { */ bool matches_response(uint8_t *ptr, uint32_t total_sz); - /** \brief Returns the header size. + /** + * \brief Returns the header size. * * This metod overrides PDU::header_size. This size includes the * payload and options size. \sa PDU::header_size diff --git a/src/udp.cpp b/src/udp.cpp index 057578f..7b7549d 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -77,19 +77,24 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par assert(total_sz >= sizeof(udphdr)); #endif const Tins::IP *ip_packet = dynamic_cast(parent); + _udp.check = 0; if(inner_pdu()) length(sizeof(udphdr) + inner_pdu()->size()); else length(sizeof(udphdr)); std::memcpy(buffer, &_udp, sizeof(udphdr)); - if(!_udp.check && ip_packet) { - uint32_t checksum = Utils::pseudoheader_checksum(ip_packet->src_addr(), ip_packet->dst_addr(), size(), Constants::IP::PROTO_UDP) + - Utils::do_checksum(buffer, buffer + total_sz); + if(ip_packet) { + uint32_t checksum = Utils::pseudoheader_checksum( + ip_packet->src_addr(), + ip_packet->dst_addr(), + size(), + Constants::IP::PROTO_UDP + ) + Utils::do_checksum(buffer, buffer + total_sz); while (checksum >> 16) checksum = (checksum & 0xffff)+(checksum >> 16); - ((udphdr*)buffer)->check = Endian::host_to_be(~checksum); + _udp.check = Endian::host_to_be(~checksum); + ((udphdr*)buffer)->check = _udp.check; } - _udp.check = 0; } bool UDP::matches_response(uint8_t *ptr, uint32_t total_sz) { diff --git a/tests/src/ip.cpp b/tests/src/ip.cpp index 6ac0284..d67fc38 100644 --- a/tests/src/ip.cpp +++ b/tests/src/ip.cpp @@ -20,7 +20,7 @@ public: }; const uint8_t IPTest::expected_packet[] = { - 40, 127, 0, 32, 0, 122, 0, 67, 21, 1, 251, 103, 84, 52, 254, 5, 192, + 40, 127, 0, 32, 0, 122, 0, 67, 21, 1, 0, 0, 84, 52, 254, 5, 192, 168, 9, 43, 130, 11, 116, 106, 103, 171, 119, 171, 104, 101, 108, 0 };