diff --git a/include/tins/ip_reassembler.h b/include/tins/ip_reassembler.h index 51a521f..1fc9c77 100644 --- a/include/tins/ip_reassembler.h +++ b/include/tins/ip_reassembler.h @@ -35,13 +35,13 @@ #include "pdu.h" #include "macros.h" #include "ip_address.h" +#include "ip.h" namespace Tins { /** * \cond */ -class IP; namespace Internals { class IPv4Fragment { public: @@ -74,6 +74,7 @@ public: void add_fragment(IP* ip); bool is_complete() const; PDU* allocate_pdu() const; + const IP& first_fragment() const; private: typedef std::vector fragments_type; @@ -81,9 +82,11 @@ private: bool extract_more_frag(const IP* ip); fragments_type fragments_; + size_t received_size_; + size_t total_size_; + IP first_fragment_; bool received_end_; uint8_t transport_proto_; - size_t received_size_, total_size_; }; } // namespace Internals diff --git a/src/ip_reassembler.cpp b/src/ip_reassembler.cpp index 577389f..41a808f 100644 --- a/src/ip_reassembler.cpp +++ b/src/ip_reassembler.cpp @@ -38,11 +38,17 @@ namespace Tins { namespace Internals { IPv4Stream::IPv4Stream() -: received_end_(false), transport_proto_(0xff), received_size_(), total_size_() { +: received_size_(), total_size_(), received_end_(false), transport_proto_(0xff) { } void IPv4Stream::add_fragment(IP* ip) { + if (fragments_.empty()) { + // Release the inner PDU, store this first fragment and restore the inner PDU + PDU* inner_pdu = ip->release_inner_pdu(); + first_fragment_ = *ip; + ip->inner_pdu(inner_pdu); + } fragments_type::iterator it = fragments_.begin(); uint16_t offset = extract_offset(ip); while (it != fragments_.end() && offset > it->offset()) { @@ -87,6 +93,10 @@ PDU* IPv4Stream::allocate_pdu() const { ); } +const IP& IPv4Stream::first_fragment() const { + return first_fragment_; +} + uint16_t IPv4Stream::extract_offset(const IP* ip) { return ip->fragment_offset() * 8; } @@ -114,6 +124,9 @@ IPv4Reassembler::PacketStatus IPv4Reassembler::process(PDU& pdu) { stream.add_fragment(ip); if (stream.is_complete()) { PDU* pdu = stream.allocate_pdu(); + // Use all field values from the first fragment + *ip = stream.first_fragment(); + // Erase this stream, since it's already assembled streams_.erase(key); // The packet is corrupt diff --git a/tests/src/ip_reassembler_test.cpp b/tests/src/ip_reassembler_test.cpp index a673d6d..ed00cf2 100644 --- a/tests/src/ip_reassembler_test.cpp +++ b/tests/src/ip_reassembler_test.cpp @@ -8,6 +8,10 @@ #include "ip.h" #include "rawpdu.h" +using std::vector; +using std::pair; +using std::make_pair; + using namespace Tins; class IPv4ReassemblerTest : public testing::Test { @@ -15,7 +19,7 @@ public: static const uint8_t packets[][1514]; static const size_t packet_sizes[], orderings[][11]; - void test_packets(const std::vector >& vt); + void test_packets(const vector >& vt); }; const uint8_t IPv4ReassemblerTest::packets[][1514] = { @@ -43,30 +47,37 @@ const size_t IPv4ReassemblerTest::orderings[][11] = { { 1, 9, 8, 5, 4, 2, 0, 7, 6, 3, 10 } }; -void IPv4ReassemblerTest::test_packets(const std::vector >& vt) { +void IPv4ReassemblerTest::test_packets(const vector >& vt) { IPv4Reassembler reassembler; for(size_t i = 0; i < vt.size(); ++i) { EthernetII eth(vt[i].first, (uint32_t)vt[i].second); + if (i != 0) { + eth.rfind_pdu().ttl(32); + } IPv4Reassembler::PacketStatus status = reassembler.process(eth); EXPECT_NE(IPv4Reassembler::NOT_FRAGMENTED, status); - if(status == IPv4Reassembler::REASSEMBLED) { + if (status == IPv4Reassembler::REASSEMBLED) { ASSERT_EQ(static_cast(vt.size() - 1), i); ASSERT_TRUE(eth.find_pdu() != NULL); + const IP& ip = eth.rfind_pdu(); + EXPECT_EQ(64, ip.ttl()); + RawPDU* raw = eth.find_pdu(); ASSERT_TRUE(raw != NULL); ASSERT_EQ(15000ULL, raw->payload().size()); } - else if(status == IPv4Reassembler::FRAGMENTED) + else if (status == IPv4Reassembler::FRAGMENTED) { EXPECT_NE(vt.size() - 1, i); + } } } TEST_F(IPv4ReassemblerTest, Reassemble) { for(size_t i = 0; i < 3; ++i) { - std::vector > vt; + vector > vt; for(size_t j = 0; j < 11; ++j) { vt.push_back( - std::make_pair( + make_pair( packets[orderings[i][j]], packet_sizes[orderings[i][j]] )