From 5185c025d814528e55d4d340feb6106d12b0cbd7 Mon Sep 17 00:00:00 2001 From: stubbfel Date: Wed, 14 Jun 2017 23:53:45 +0200 Subject: [PATCH] add ip fragmentation test and fixes --- lib/libtins | 2 +- src/IpPacketFragmentation.cpp | 19 +++-- src/IpPacketFragmentation.h | 4 +- src/IpPacketFragmentation_t.h | 2 +- test/src/TestIp6Fragmentation.cpp | 99 ++++++++++++++++++++++++++ test/src/TestIp6ToIp4PacketHandler.cpp | 36 +++++----- 6 files changed, 133 insertions(+), 29 deletions(-) create mode 100644 test/src/TestIp6Fragmentation.cpp diff --git a/lib/libtins b/lib/libtins index 7d087e6..f06a508 160000 --- a/lib/libtins +++ b/lib/libtins @@ -1 +1 @@ -Subproject commit 7d087e6fb88a7c5532daea41e9b49871ad9b0816 +Subproject commit f06a508b2c61d530d4e3e576058b05b80b655bf9 diff --git a/src/IpPacketFragmentation.cpp b/src/IpPacketFragmentation.cpp index 3f41248..f6524f0 100644 --- a/src/IpPacketFragmentation.cpp +++ b/src/IpPacketFragmentation.cpp @@ -2,6 +2,8 @@ #include #include +const size_t IpPacketFragmentation::fragmentionHeadersize = sizeof (FragmentionHeaderUnion) - 1; + IpPacketFragmentation::IpPacketFragmentation(const size_t newMtu) : mtu(newMtu), idCounter() { } @@ -12,10 +14,10 @@ IpPacketFragmentation::~IpPacketFragmentation() } -void IpPacketFragmentation::addExtensionHeader(IN const uint8_t nextHeader, IN const uint8_t *startPtr, IN Tins::IPv6 & ipFragmentPdu) +void IpPacketFragmentation::addExtensionHeader(IN const uint8_t nextHeader, IN const uint8_t *startPtr, IN Tins::IPv6 & ipFragmentPdu, const size_t headersize) { - Tins::IPv6::ext_header fragmentionHeader(nextHeader, fragmentionHeadersize , startPtr); - ipFragmentPdu.add_ext_header(fragmentionHeader); + Tins::IPv6::ext_header extensionHeader(nextHeader, headersize , startPtr); + ipFragmentPdu.add_ext_header(extensionHeader); } void IpPacketFragmentation::initFragmentationHeader(FragmentionHeaderStruct* ptrFragmentionHeaderStruct) @@ -35,7 +37,7 @@ bool IpPacketFragmentation::handle(IN const Tins::PDU & pdu, IN IPacketHandler * } const size_t originPduSize = pdu.size(); - const size_t fragmentationCount = originPduSize + 1; + const size_t fragmentationCount = (originPduSize / mtu) + 1; if (fragmentationCount < 2) { return callBackHandler->handle(pdu, this); @@ -48,7 +50,7 @@ bool IpPacketFragmentation::handle(IN const Tins::PDU & pdu, IN IPacketHandler * } Tins::PDU * ipDataPdu = ipPdu->inner_pdu(); - if (ipPdu == nullptr) + if (ipDataPdu == nullptr) { return false; } @@ -91,8 +93,11 @@ bool IpPacketFragmentation::createAndForwardFragmend(IN const Tins::PDU & pdu, I return false; } - addExtensionHeader(ptrFragmentionHeaderStruct->NextHeader, ptrStartFragmentionHeader, *ipFragmentPdu); + // Tins::IPv6::ext_header fragmentionHeader(ptrFragmentionHeaderStruct->NextHeader, IpPacketFragmentation::fragmentionHeadersize , ptrStartFragmentionHeader); + // ipFragmentPdu->add_ext_header(fragmentionHeader); + addExtensionHeader(ptrFragmentionHeaderStruct->NextHeader, ptrStartFragmentionHeader, *ipFragmentPdu, IpPacketFragmentation::fragmentionHeadersize); + addExtensionHeader(Tins::IPv6::NO_NEXT_HEADER, nullptr, *ipFragmentPdu, 0); SPtrRawPDU rawFragmentPdu = std::make_shared(fragmentPayload->data(), static_cast(fragmentPayload->size())); ipFragmentPdu->inner_pdu(rawFragmentPdu.get()); - return callBackHandler->handle(pdu, this); + return callBackHandler->handle(*fragmentPdu, this); } diff --git a/src/IpPacketFragmentation.h b/src/IpPacketFragmentation.h index 445f911..8c348d1 100644 --- a/src/IpPacketFragmentation.h +++ b/src/IpPacketFragmentation.h @@ -17,8 +17,8 @@ public: private: const size_t mtu; uint32_t idCounter; - static const size_t fragmentionHeadersize = sizeof (FragmentionHeaderUnion) - 1; - static void addExtensionHeader(IN const uint8_t NextHeader, IN const uint8_t *startPtr, IN Tins::IPv6 & ipFragmentPdu); + static const size_t fragmentionHeadersize; + static void addExtensionHeader(IN const uint8_t NextHeader, IN const uint8_t *startPtr, IN Tins::IPv6 & ipFragmentPdu, IN const size_t headersize); bool createAndForwardFragmend(IN const Tins::PDU & pdu, IN const ByteVector::iterator & fragmentStart, IN const ByteVector::iterator & fragmentPosIt, IN FragmentionHeaderStruct * ptrFragmentionHeaderStruct, IN uint8_t * ptrStartFragmentionHeader, IN IPacketHandler * callBackHandler); void initFragmentationHeader(FragmentionHeaderStruct* ptrFragmentionHeaderStruct); diff --git a/src/IpPacketFragmentation_t.h b/src/IpPacketFragmentation_t.h index 7154f7e..a0e9070 100644 --- a/src/IpPacketFragmentation_t.h +++ b/src/IpPacketFragmentation_t.h @@ -30,7 +30,7 @@ struct FragmentionHeaderStruct uint32_t Identification; }; -struct FragmentionHeaderUnion +union FragmentionHeaderUnion { FragmentionHeaderStruct Structed; uint8_t Bytes[8]; diff --git a/test/src/TestIp6Fragmentation.cpp b/test/src/TestIp6Fragmentation.cpp new file mode 100644 index 0000000..2ae4e48 --- /dev/null +++ b/test/src/TestIp6Fragmentation.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include +#include +#include +#include "IpPacketFragmentation.h" + +using namespace fakeit; + +namespace TestIp6PacketFragmentation +{ + static Tins::PDU * currentInputPdu = nullptr; + static size_t fragmentationCount = 0; + static long maxfragmentationSize = 0; + static bool firstFragment = true; + static uint32_t id = -1; + + void compareToInputPdu(const Tins::PDU & answerPdu) + { + const Tins::IPv6 * ipPdu = answerPdu.find_pdu(); + if (ipPdu == nullptr) + { + FAIL("got no ip6 packet"); + } + + Tins::PDU * ipDataPdu = ipPdu->inner_pdu(); + if (ipDataPdu == nullptr) + { + FAIL("got no ip6 inner pdu"); + } + + const long ipPayloadSize = static_cast(ipDataPdu->size()); + REQUIRE(ipPayloadSize <= maxfragmentationSize); + REQUIRE(ipPayloadSize > 0); + + const Tins::IPv6::ext_header * fragmentHeader = ipPdu->search_header(Tins::IPv6::FRAGMENT); + if (fragmentHeader == nullptr) + { + FAIL("got no fragment header"); + } + + FragmentionHeaderUnion fragmentionHeaderUnion; + FragmentionHeaderStruct * ptrFragmentionHeaderStruct = &fragmentionHeaderUnion.Structed; + std::memcpy(&fragmentionHeaderUnion.Bytes[1], fragmentHeader->data_ptr(), fragmentHeader->data_size()); + uint16_t mFlag = fragmentationCount > 1 ? 1 : 0; + if (firstFragment) + { + firstFragment = false; + id = ptrFragmentionHeaderStruct->Identification; + } + REQUIRE(fragmentHeader->option() == Tins::IPv6::FRAGMENT); + REQUIRE(ptrFragmentionHeaderStruct->Reserved == 0); + REQUIRE(ptrFragmentionHeaderStruct->Res == 0); + REQUIRE(ptrFragmentionHeaderStruct->MFlag == mFlag); + REQUIRE(ptrFragmentionHeaderStruct->Identification == id); + fragmentationCount--; + } +} + +#ifndef TEST_MTU +#define TEST_MTU 1500 +#endif + +TEST_CASE( "test Ip6PacketFragmentation", "[Ip6PacketFragmentation]" ) { + const size_t mtu = TEST_MTU; + ByteVector mtu_Payload; + for (size_t i = 0; i < mtu; i++) + { + mtu_Payload.push_back(0); + } + + Tins::RawPDU mtuPayload(mtu_Payload.data(), mtu); + IpPacketFragmentation handler(1500); + Mock mockHandler; + When(Method(mockHandler, handle)).AlwaysDo([] (IN const Tins::PDU & pdu,...) + { + TestIp6PacketFragmentation::compareToInputPdu(pdu); + return true; + }); + + Tins::EthernetII pkt = Tins::EthernetII("11:22:33:44:55:66", "66:55:44:33:22:11") / mtuPayload; + TestIp6PacketFragmentation::currentInputPdu = &pkt; + TestIp6PacketFragmentation::fragmentationCount = 2; + TestIp6PacketFragmentation::maxfragmentationSize = mtu/2; + + REQUIRE(handler.handle(pkt, nullptr) == false); + Verify(Method(mockHandler, handle)).Never(); + REQUIRE(handler.handle(pkt, &mockHandler.get()) == false); + Verify(Method(mockHandler, handle)).Never(); + + pkt = Tins::EthernetII("11:22:33:44:55:66", "66:55:44:33:22:11") / mtuPayload / Tins::IPv6("::1", "::2"); + REQUIRE(handler.handle(pkt, &mockHandler.get()) == false); + + + pkt = Tins::EthernetII("11:22:33:44:55:66", "66:55:44:33:22:11") / Tins::IPv6("::1", "::2") / mtuPayload; + REQUIRE(handler.handle(pkt, &mockHandler.get()) == true); + Verify(Method(mockHandler, handle)).Twice(); +} diff --git a/test/src/TestIp6ToIp4PacketHandler.cpp b/test/src/TestIp6ToIp4PacketHandler.cpp index 61d5db0..790ba94 100644 --- a/test/src/TestIp6ToIp4PacketHandler.cpp +++ b/test/src/TestIp6ToIp4PacketHandler.cpp @@ -44,17 +44,17 @@ namespace TestIp6ToIp4PacketHandler } TEST_CASE( "test Ip6ToIp4PacketHandler", "[Ip6ToIp4PacketHandler]" ) { - Ip6ToIp4PacketHandler handler; + Ip6ToIp4PacketHandler handler; Mock mockHandler; When(Method(mockHandler, handle)).AlwaysDo([] (IN const Tins::PDU & pdu,...) { - TestIp6ToIp4PacketHandler::compareToInputPdu(pdu); + TestIp6ToIp4PacketHandler::compareToInputPdu(pdu); return true; }); Tins::IPv6 tmpPkt = Tins::IPv6() / Tins::TCP(); Tins::EthernetII pkt = Tins::EthernetII() / tmpPkt; - TestIp6ToIp4PacketHandler::currentInputPdu = &pkt; + TestIp6ToIp4PacketHandler::currentInputPdu = &pkt; REQUIRE(handler.handle(pkt, nullptr) == false); REQUIRE(handler.handle(pkt, &mockHandler.get()) == true); Verify(Method(mockHandler, handle)).Once(); @@ -63,21 +63,21 @@ TEST_CASE( "test Ip6ToIp4PacketHandler", "[Ip6ToIp4PacketHandler]" ) { REQUIRE(handler.handle(pkt, &mockHandler.get()) == true); Verify(Method(mockHandler, handle)).Twice(); - // test pkt hashes - Tins::EthernetII * clonePkt = pkt.clone(); - REQUIRE((long) clonePkt != (long) &pkt); - Tins::PDU::serialization_type ser_pkt = pkt.serialize(); - Tins::PDU::serialization_type ser_clonepkt = clonePkt->serialize(); - REQUIRE(ser_pkt == ser_clonepkt); + // test pkt hashes + Tins::EthernetII * clonePkt = pkt.clone(); + REQUIRE((long) clonePkt != (long) &pkt); + Tins::PDU::serialization_type ser_pkt = pkt.serialize(); + Tins::PDU::serialization_type ser_clonepkt = clonePkt->serialize(); + REQUIRE(ser_pkt == ser_clonepkt); - Tins::PDU & pdu = *((Tins::PDU *)&pkt); - std::hashpdu_hash; - std::size_t h1 = pdu_hash(ser_pkt); - std::size_t h2 = pdu_hash(ser_clonepkt); + Tins::PDU & pdu = *((Tins::PDU *)&pkt); + std::hashpdu_hash; + std::size_t h1 = pdu_hash(ser_pkt); + std::size_t h2 = pdu_hash(ser_clonepkt); - REQUIRE(h1 == h2); - pkt = Tins::EthernetII("11:22:33:44:55:66", "66:55:44:33:22:11") / Tins::IPv6("::2", "::1") / Tins::TCP(); - ser_pkt = pkt.serialize(); - h1 = pdu_hash(ser_pkt); - REQUIRE(h1 != h2); + REQUIRE(h1 == h2); + pkt = Tins::EthernetII("11:22:33:44:55:66", "66:55:44:33:22:11") / Tins::IPv6("::2", "::1") / Tins::TCP(); + ser_pkt = pkt.serialize(); + h1 = pdu_hash(ser_pkt); + REQUIRE(h1 != h2); }