diff --git a/include/tins/loopback.h b/include/tins/loopback.h index d9b5629..8271bb4 100644 --- a/include/tins/loopback.h +++ b/include/tins/loopback.h @@ -86,6 +86,15 @@ public: */ PDUType pdu_type() const { return pdu_flag; } + /** + * \brief Check wether ptr points to a valid response for this PDU. + * + * \sa PDU::matches_response + * \param ptr The pointer to the buffer. + * \param total_sz The size of the buffer. + */ + bool matches_response(const uint8_t *ptr, uint32_t total_sz) const; + /** * \sa PDU::clone */ diff --git a/src/ip.cpp b/src/ip.cpp index 252ce11..6928356 100644 --- a/src/ip.cpp +++ b/src/ip.cpp @@ -432,7 +432,7 @@ bool IP::matches_response(const uint8_t *ptr, uint32_t total_sz) const { // checks for broadcast addr if((_ip.saddr == ip_ptr->daddr && (_ip.daddr == ip_ptr->saddr || dst_addr().is_broadcast())) || (dst_addr().is_broadcast() && _ip.saddr == 0)) { - uint32_t sz = std::min(_ip.ihl * sizeof(uint32_t), total_sz); + uint32_t sz = std::min(header_size(), total_sz); return inner_pdu() ? inner_pdu()->matches_response(ptr + sz, total_sz - sz) : true; } return false; diff --git a/src/loopback.cpp b/src/loopback.cpp index c5844a9..4ae2532 100644 --- a/src/loopback.cpp +++ b/src/loopback.cpp @@ -93,8 +93,7 @@ uint32_t Loopback::header_size() const { return sizeof(_family); } -void Loopback::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *) -{ +void Loopback::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *) { #ifdef TINS_DEBUG assert(total_sz >= sizeof(_family)); #endif @@ -106,6 +105,19 @@ void Loopback::write_serialization(uint8_t *buffer, uint32_t total_sz, const PDU *reinterpret_cast(buffer) = _family; #endif // WIN32 } + +bool Loopback::matches_response(const uint8_t *ptr, uint32_t total_sz) const { + if(total_sz < sizeof(_family)) { + return false; + } + // If there's an inner_pdu, check if the inner pdu matches. + // Otherwise, just check this loopback family. + + return inner_pdu() ? + inner_pdu()->matches_response(ptr + sizeof(_family), total_sz - sizeof(_family)) : + (_family == *reinterpret_cast(ptr)); +} + #ifdef BSD void Loopback::send(PacketSender &sender, const NetworkInterface &iface) { if(!iface) diff --git a/tests/src/CMakeLists.txt b/tests/src/CMakeLists.txt index 75d0d95..d5914ef 100644 --- a/tests/src/CMakeLists.txt +++ b/tests/src/CMakeLists.txt @@ -1,5 +1,5 @@ # Use libtins' include directories + test include directories -INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/include/ ../include/) +INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/include/tins/ ../include/) # Find pthread library FIND_PACKAGE(Threads REQUIRED) @@ -52,6 +52,7 @@ ADD_CUSTOM_TARGET( IPv6Test IPv6AddressTest LLCTest + LoopbackTest MatchesResponseTest NetworkInterfaceTest OfflinePacketFilterTest @@ -92,6 +93,7 @@ ADD_EXECUTABLE(IPSecTest EXCLUDE_FROM_ALL ipsec.cpp) ADD_EXECUTABLE(IPv6Test EXCLUDE_FROM_ALL ipv6.cpp) ADD_EXECUTABLE(IPv6AddressTest EXCLUDE_FROM_ALL ipv6address.cpp) ADD_EXECUTABLE(LLCTest EXCLUDE_FROM_ALL llc.cpp) +ADD_EXECUTABLE(LoopbackTest EXCLUDE_FROM_ALL loopback.cpp) ADD_EXECUTABLE(MatchesResponseTest EXCLUDE_FROM_ALL matches_response.cpp) ADD_EXECUTABLE(NetworkInterfaceTest EXCLUDE_FROM_ALL network_interface.cpp) ADD_EXECUTABLE(OfflinePacketFilterTest EXCLUDE_FROM_ALL offline_packet_filter.cpp) @@ -170,6 +172,7 @@ ADD_TEST(IPSec IPSecTest) ADD_TEST(IPv6 IPv6Test) ADD_TEST(IPv6Address IPv6AddressTest) ADD_TEST(LLC LLCTest) +ADD_TEST(Loopback LoopbackTest) ADD_TEST(MatchesResponse MatchesResponseTest) ADD_TEST(NetworkInterface NetworkInterfaceTest) ADD_TEST(OfflinePacketFilter OfflinePacketFilterTest) diff --git a/tests/src/loopback.cpp b/tests/src/loopback.cpp new file mode 100644 index 0000000..85c8eb5 --- /dev/null +++ b/tests/src/loopback.cpp @@ -0,0 +1,35 @@ +#include +#include +#include +#include +#include "macros.h" +#ifndef WIN32 + #include + #ifdef BSD + #include + #include + #include + #endif +#endif +#include "loopback.h" +#include "ip.h" +#include "tcp.h" + +using namespace std; +using namespace Tins; + +class LoopbackTest : public testing::Test { +public: + +}; + +#ifndef WIN32 +TEST_F(LoopbackTest, MatchesResponse) { + Loopback loop1 = Loopback() / IP("192.168.0.1", "192.168.0.2") / TCP(22, 21); + loop1.family(PF_INET); + Loopback loop2 = Loopback() / IP("192.168.0.2", "192.168.0.1") / TCP(21, 22); + loop2.family(PF_INET); + PDU::serialization_type buffer = loop2.serialize(); + EXPECT_TRUE(loop1.matches_response(&buffer[0], buffer.size())); +} +#endif // WIN32