diff --git a/src/udp.cpp b/src/udp.cpp index 526c0aa..f14938f 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -36,6 +36,7 @@ #include "constants.h" #include "utils.h" #include "ip.h" +#include "ipv6.h" #include "rawpdu.h" #include "exceptions.h" @@ -78,13 +79,13 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par #ifdef TINS_DEBUG assert(total_sz >= sizeof(udphdr)); #endif - const Tins::IP *ip_packet = dynamic_cast(parent); _udp.check = 0; if(inner_pdu()) length(sizeof(udphdr) + inner_pdu()->size()); else length(sizeof(udphdr)); std::memcpy(buffer, &_udp, sizeof(udphdr)); + const Tins::IP *ip_packet = dynamic_cast(parent); if(ip_packet) { uint32_t checksum = Utils::pseudoheader_checksum( ip_packet->src_addr(), @@ -97,6 +98,21 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par _udp.check = Endian::host_to_be(~checksum); ((udphdr*)buffer)->check = _udp.check; } + else { + const Tins::IPv6 *ip6_packet = dynamic_cast(parent); + if(ip6_packet) { + uint32_t checksum = Utils::pseudoheader_checksum( + ip6_packet->src_addr(), + ip6_packet->dst_addr(), + size(), + Constants::IP::PROTO_UDP + ) + Utils::do_checksum(buffer, buffer + total_sz); + while (checksum >> 16) + checksum = (checksum & 0xffff)+(checksum >> 16); + _udp.check = Endian::host_to_be(~checksum); + ((udphdr*)buffer)->check = _udp.check; + } + } } bool UDP::matches_response(const uint8_t *ptr, uint32_t total_sz) const {