diff --git a/include/cxxstd.h b/include/cxxstd.h index e54d7fe..3d59cff 100644 --- a/include/cxxstd.h +++ b/include/cxxstd.h @@ -30,12 +30,27 @@ #ifndef TINS_CXXSTD_H #define TINS_CXXSTD_H +#include + #ifdef __GXX_EXPERIMENTAL_CXX0X__ #define TINS_CXXSTD_GCC_FIX 1 #else #define TINS_CXXSTD_GCC_FIX 0 #endif // __GXX_EXPERIMENTAL_CXX0X__ +namespace Tins{ +namespace Internals { +template +struct smart_ptr { +#if TINS_IS_CXX11 + typedef std::unique_ptr type; +#else + typedef std::auto_ptr type; +#endif +}; +} +} + #define TINS_IS_CXX11 (__cplusplus > 199711L || TINS_CXXSTD_GCC_FIX == 1) #endif // TINS_CXXSTD_H diff --git a/include/exceptions.h b/include/exceptions.h index 73bf0ef..20f651a 100644 --- a/include/exceptions.h +++ b/include/exceptions.h @@ -48,7 +48,7 @@ public: }; /** - * \brief Exception thrown when an option is not found. + * \brief Exception thrown when a malformed packet is parsed. */ class malformed_packet : public std::runtime_error { public: @@ -56,7 +56,20 @@ public: : std::runtime_error(std::string()) { } const char* what() const throw() { - return "Option not found"; + return "Malformed packet"; + } +}; + +/** + * \brief Exception thrown when a PDU is not found when using PDU::rfind_pdu. + */ +class pdu_not_found : public std::runtime_error { +public: + pdu_not_found() + : std::runtime_error(std::string()) { } + + const char* what() const throw() { + return "PDU not found"; } }; } diff --git a/include/pdu.h b/include/pdu.h index 9f7f78e..ff5ef5a 100644 --- a/include/pdu.h +++ b/include/pdu.h @@ -35,6 +35,7 @@ #include #include "macros.h" #include "cxxstd.h" +#include "exceptions.h" /** \brief The Tins namespace. */ @@ -220,14 +221,14 @@ namespace Tins { serialization_type serialize(); /** - * \brief Find and returns the first PDU that matches the given flag. + * \brief Finds and returns the first PDU that matches the given flag. * * This method searches for the first PDU which has the same type flag as * the given one. If the first PDU matches that flag, it is returned. * If no PDU matches, 0 is returned. * \param flag The flag which being searched. */ - template + template T *find_pdu(PDUType type = T::pdu_flag) { PDU *pdu = this; while(pdu) { @@ -239,15 +240,42 @@ namespace Tins { } /** - * \brief Find and returns the first PDU that matches the given flag. + * \brief Finds and returns the first PDU that matches the given flag. * * \param flag The flag which being searched. */ - template + template const T *find_pdu(PDUType type = T::pdu_flag) const { return const_cast(this)->find_pdu(); } + /** + * \brief Finds and returns the first PDU that matches the given flag. + * + * If the PDU is not found, a pdu_not_found exception is thrown. + * + * \sa PDU::find_pdu + * + * \param flag The flag which being searched. + */ + template + T &rfind_pdu(PDUType type = T::pdu_flag) { + T *ptr = find_pdu(type); + if(!ptr) + throw pdu_not_found(); + return *ptr; + } + + /** + * \brief Finds and returns the first PDU that matches the given flag. + * + * \param flag The flag which being searched. + */ + template + const T &rfind_pdu(PDUType type = T::pdu_flag) const { + return const_cast(this)->rfind_pdu(); + } + /** * \brief Clones this packet. * diff --git a/include/sniffer.h b/include/sniffer.h index f285571..0c32ae5 100644 --- a/include/sniffer.h +++ b/include/sniffer.h @@ -272,7 +272,7 @@ namespace Tins { bool ret_val(false); LoopData *data = reinterpret_cast*>(args); try { - std::auto_ptr pdu; + Internals::smart_ptr::type pdu; if(data->iface_type == DLT_EN10MB) ret_val = call_functor(data, packet, header); else if(data->iface_type == DLT_IEEE802_11_RADIO) diff --git a/src/dns_record.cpp b/src/dns_record.cpp index 59dac33..0fd9a3b 100644 --- a/src/dns_record.cpp +++ b/src/dns_record.cpp @@ -54,7 +54,7 @@ DNSResourceRecord::DNSResourceRecord(DNSRRImpl *impl, DNSResourceRecord::DNSResourceRecord(const uint8_t *buffer, uint32_t size) { const uint8_t *buffer_end = buffer + size; - std::auto_ptr tmp_impl; + Internals::smart_ptr::type tmp_impl; if((*buffer & 0xc0)) { uint16_t offset(*reinterpret_cast(buffer)); offset = Endian::be_to_host(offset) & 0x3fff; diff --git a/src/utils.cpp b/src/utils.cpp index ca76ce6..cf7402f 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -117,11 +117,7 @@ bool resolve_hwaddr(const NetworkInterface &iface, IPv4Address ip, IPv4Address my_ip; NetworkInterface::Info info(iface.addresses()); EthernetII packet = ARP::make_arp_request(iface, ip, info.ip_addr, info.hw_addr); - #if TINS_IS_CXX11 - std::unique_ptr response(sender.send_recv(packet)); - #else - std::auto_ptr response(sender.send_recv(packet)); - #endif + Internals::smart_ptr::type response(sender.send_recv(packet)); if(response.get()) { ARP *arp_resp = response->find_pdu(); if(arp_resp) @@ -137,12 +133,7 @@ HWAddress<6> resolve_hwaddr(const NetworkInterface &iface, IPv4Address ip, Packe IPv4Address my_ip; NetworkInterface::Info info(iface.addresses()); EthernetII packet = ARP::make_arp_request(iface, ip, info.ip_addr, info.hw_addr); - #if TINS_IS_CXX11 - std::unique_ptr - #else - std::auto_ptr - #endif - response(sender.send_recv(packet)); + Internals::smart_ptr::type response(sender.send_recv(packet)); if(response.get()) { const ARP *arp_resp = response->find_pdu(); if(arp_resp) diff --git a/tests/src/dot11/data.cpp b/tests/src/dot11/data.cpp index 95354f7..69463cb 100644 --- a/tests/src/dot11/data.cpp +++ b/tests/src/dot11/data.cpp @@ -3,6 +3,7 @@ #include #include #include "dot11.h" +#include "cxxstd.h" #include "tests/dot11.h" @@ -61,12 +62,12 @@ TEST_F(Dot11DataTest, SeqNum) { TEST_F(Dot11DataTest, ClonePDU) { Dot11Data dot1(expected_packet, sizeof(expected_packet)); - std::auto_ptr dot2(dot1.clone()); + Internals::smart_ptr::type dot2(dot1.clone()); test_equals(dot1, *dot2); } TEST_F(Dot11DataTest, FromBytes) { - std::auto_ptr dot11(Dot11::from_bytes(expected_packet, sizeof(expected_packet))); + Internals::smart_ptr::type dot11(Dot11::from_bytes(expected_packet, sizeof(expected_packet))); ASSERT_TRUE(dot11.get()); const Dot11Data *inner = dot11->find_pdu(); ASSERT_TRUE(inner); @@ -94,7 +95,7 @@ TEST_F(Dot11DataTest, PCAPLoad1) { EXPECT_EQ(dot1.from_ds(), 1); EXPECT_EQ(dot1.frag_num(), 0); EXPECT_EQ(dot1.seq_num(), 1945); - std::auto_ptr dot2(dot1.clone()); + Internals::smart_ptr::type dot2(dot1.clone()); test_equals(dot1, *dot2); } diff --git a/tests/src/pdu.cpp b/tests/src/pdu.cpp index 1d29fca..1161ecc 100644 --- a/tests/src/pdu.cpp +++ b/tests/src/pdu.cpp @@ -4,6 +4,7 @@ #include #include "ip.h" #include "tcp.h" +#include "udp.h" #include "rawpdu.h" #include "pdu.h" #include "packet.h" @@ -15,6 +16,18 @@ class PDUTest : public testing::Test { public: }; +TEST_F(PDUTest, FindPDU) { + IP ip = IP("192.168.0.1") / TCP(22, 52) / RawPDU("Test"); + EXPECT_TRUE(ip.find_pdu()); + EXPECT_TRUE(ip.find_pdu()); + EXPECT_FALSE(ip.find_pdu()); + TCP &t1 = ip.rfind_pdu(); + const TCP &t2 = ip.rfind_pdu(); + (void)t1; + (void)t2; + EXPECT_THROW(ip.rfind_pdu(), pdu_not_found); +} + TEST_F(PDUTest, OperatorConcat) { std::string raw_payload = "Test"; IP ip = IP("192.168.0.1") / TCP(22, 52) / RawPDU(raw_payload);