diff --git a/include/crypto.h b/include/crypto.h index 99d704f..6b55855 100644 --- a/include/crypto.h +++ b/include/crypto.h @@ -49,7 +49,7 @@ namespace Crypto { /** * \cond */ - class RC4Key; + struct RC4Key; #ifdef HAVE_WPA2_DECRYPTION namespace WPA2 { class invalid_handshake : public std::exception { diff --git a/include/exceptions.h b/include/exceptions.h index 99873cf..338ce48 100644 --- a/include/exceptions.h +++ b/include/exceptions.h @@ -145,6 +145,16 @@ public: return "Malformed option"; } }; + +/** + * \brief Exception thrown when a call to tins_cast fails. + */ +class bad_tins_cast : public std::exception { +public: + const char *what() const throw() { + return "Bad Tins cast"; + } +}; } #endif // TINS_EXCEPTIONS_H diff --git a/include/pdu.h b/include/pdu.h index e1afb55..f6f5a49 100644 --- a/include/pdu.h +++ b/include/pdu.h @@ -456,6 +456,34 @@ namespace Tins { *lop /= rop; return lop; } + + namespace Internals { + template + struct remove_pointer { + typedef T type; + }; + + template + struct remove_pointer { + typedef T type; + }; + } + + template + T tins_cast(U *pdu) { + typedef typename Internals::remove_pointer::type TrueT; + return pdu && (TrueT::pdu_flag == pdu->pdu_type()) ? + static_cast(pdu) : + 0; + } + + template + T &tins_cast(U &pdu) { + T *ptr = tins_cast(&pdu); + if(!ptr) + throw bad_tins_cast(); + return *ptr; + } } #endif // TINS_PDU_H diff --git a/src/icmpv6.cpp b/src/icmpv6.cpp index 5cd4084..4086c74 100644 --- a/src/icmpv6.cpp +++ b/src/icmpv6.cpp @@ -225,7 +225,7 @@ void ICMPv6::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU * buffer = write_option(*it, buffer); } if(!_header.cksum) { - const Tins::IPv6 *ipv6 = dynamic_cast(parent); + const Tins::IPv6 *ipv6 = tins_cast(parent); if(ipv6) { uint32_t checksum = Utils::pseudoheader_checksum( ipv6->src_addr(), diff --git a/src/loopback.cpp b/src/loopback.cpp index 763200f..124255d 100644 --- a/src/loopback.cpp +++ b/src/loopback.cpp @@ -96,9 +96,9 @@ void Loopback::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU #ifdef TINS_DEBUG assert(total_sz >= sizeof(_family)); #endif - if(dynamic_cast(inner_pdu())) + if(tins_cast(inner_pdu())) _family = PF_INET; - else if(dynamic_cast(inner_pdu())) + else if(tins_cast(inner_pdu())) _family = PF_LLC; *reinterpret_cast(buffer) = _family; } diff --git a/src/radiotap.cpp b/src/radiotap.cpp index 5cb8936..063a38a 100644 --- a/src/radiotap.cpp +++ b/src/radiotap.cpp @@ -276,7 +276,7 @@ void RadioTap::send(PacketSender &sender, const NetworkInterface &iface) { addr.sll_halen = 6; addr.sll_ifindex = iface.id(); - Tins::Dot11 *wlan = dynamic_cast(inner_pdu()); + const Tins::Dot11 *wlan = tins_cast(inner_pdu()); if(wlan) { Tins::Dot11::address_type dot11_addr(wlan->addr1()); std::copy(dot11_addr.begin(), dot11_addr.end(), addr.sll_addr); diff --git a/src/tcp.cpp b/src/tcp.cpp index 1b6d516..6aa18c5 100644 --- a/src/tcp.cpp +++ b/src/tcp.cpp @@ -309,7 +309,7 @@ void TCP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par memcpy(tcp_start, &_tcp, sizeof(tcphdr)); - const Tins::IP *ip_packet = dynamic_cast(parent); + const Tins::IP *ip_packet = tins_cast(parent); if(ip_packet) { uint32_t check = Utils::pseudoheader_checksum(ip_packet->src_addr(), ip_packet->dst_addr(), @@ -321,7 +321,7 @@ void TCP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par ((tcphdr*)tcp_start)->check = _tcp.check; } else { - const Tins::IPv6 *ipv6_packet = dynamic_cast(parent); + const Tins::IPv6 *ipv6_packet = tins_cast(parent); if(ipv6_packet) { uint32_t check = Utils::pseudoheader_checksum(ipv6_packet->src_addr(), ipv6_packet->dst_addr(), diff --git a/src/udp.cpp b/src/udp.cpp index f14938f..72c15f5 100644 --- a/src/udp.cpp +++ b/src/udp.cpp @@ -85,7 +85,7 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par else length(sizeof(udphdr)); std::memcpy(buffer, &_udp, sizeof(udphdr)); - const Tins::IP *ip_packet = dynamic_cast(parent); + const Tins::IP *ip_packet = tins_cast(parent); if(ip_packet) { uint32_t checksum = Utils::pseudoheader_checksum( ip_packet->src_addr(), @@ -99,7 +99,7 @@ void UDP::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *par ((udphdr*)buffer)->check = _udp.check; } else { - const Tins::IPv6 *ip6_packet = dynamic_cast(parent); + const Tins::IPv6 *ip6_packet = tins_cast(parent); if(ip6_packet) { uint32_t checksum = Utils::pseudoheader_checksum( ip6_packet->src_addr(), diff --git a/tests/src/pdu.cpp b/tests/src/pdu.cpp index 1161ecc..47aaa59 100644 --- a/tests/src/pdu.cpp +++ b/tests/src/pdu.cpp @@ -67,3 +67,15 @@ TEST_F(PDUTest, OperatorConcatOnPacket) { ASSERT_EQ(raw->payload_size(), raw_payload.size()); EXPECT_TRUE(std::equal(raw->payload().begin(), raw->payload().end(), raw_payload.begin())); } + +TEST_F(PDUTest, TinsCast) { + PDU *null_pdu = 0; + TCP tcp; + PDU *pdu = &tcp; + EXPECT_EQ(tins_cast(pdu), &tcp); + EXPECT_EQ(tins_cast(pdu), &tcp); + EXPECT_EQ(tins_cast(null_pdu), null_pdu); + EXPECT_EQ(tins_cast(pdu), null_pdu); + EXPECT_THROW(tins_cast(*pdu), bad_tins_cast); +} +