From 48c068b84ac106d8a2c9b7e7cac712921bcfff7a Mon Sep 17 00:00:00 2001 From: Matias Fontanini Date: Sat, 13 Feb 2016 11:23:08 -0800 Subject: [PATCH] Add callbacks for stream termination events --- include/tins/tcp_ip/flow.h | 31 +++++++++++++-------- include/tins/tcp_ip/stream.h | 12 ++++---- include/tins/tcp_ip/stream_follower.h | 40 +++++++++++++++++++++++---- src/tcp_ip/flow.cpp | 30 ++++++++++++++------ src/tcp_ip/stream.cpp | 4 +-- src/tcp_ip/stream_follower.cpp | 31 ++++++++++++++++----- tests/src/tcp_ip.cpp | 9 +++++- 7 files changed, 116 insertions(+), 41 deletions(-) diff --git a/include/tins/tcp_ip/flow.h b/include/tins/tcp_ip/flow.h index 69b0858..d48a1af 100644 --- a/include/tins/tcp_ip/flow.h +++ b/include/tins/tcp_ip/flow.h @@ -106,7 +106,7 @@ public: */ typedef std::function out_of_order_callback_type; + const payload_type&)> flow_packet_callback_type; /** * Construct a Flow from an IPv4 address @@ -146,7 +146,7 @@ public: * * \param callback The callback to be executed */ - void out_of_order_callback(const out_of_order_callback_type& callback); + void out_of_order_callback(const flow_packet_callback_type& callback); /** * \brief Processes a packet. @@ -188,54 +188,59 @@ public: bool packet_belongs(const PDU& packet) const; /** - * \brief Getter for the IPv4 destination address + * \brief Retrieves the IPv4 destination address * * Note that it's only safe to execute this method if is_v6() == false */ IPv4Address dst_addr_v4() const; /** - * \brief Getter for the IPv6 destination address + * \brief Retrieves the IPv6 destination address * * Note that it's only safe to execute this method if is_v6() == true */ IPv6Address dst_addr_v6() const; /** - * Getter for this flow's destination port + * Retrieves this flow's destination port */ uint16_t dport() const; /** - * Getter for this flow's payload (const) + * Retrieves this flow's payload (const) */ const payload_type& payload() const; /** - * Getter for this flow's destination port + * Retrieves this flow's destination port */ payload_type& payload(); /** - * Getter for this flow's state + * Retrieves this flow's state */ State state() const; /** - * Getter for this flow's sequence number + * Retrieves this flow's sequence number */ uint32_t sequence_number() const; /** - * Getter for this flow's buffered payload (const) + * Retrieves this flow's buffered payload (const) */ const buffered_payload_type& buffered_payload() const; /** - * Getter for this flow's buffered payload + * Retrieves this flow's buffered payload */ buffered_payload_type& buffered_payload(); + /** + * Retrieves this flow's total buffered bytes + */ + uint32_t total_buffered_bytes() const; + /** * Sets the state of this flow * @@ -277,14 +282,16 @@ private: void store_payload(uint32_t seq, payload_type payload); buffered_payload_type::iterator erase_iterator(buffered_payload_type::iterator iter); void update_state(const TCP& tcp); + void initialize(); payload_type payload_; buffered_payload_type buffered_payload_; uint32_t seq_number_; + uint32_t total_buffered_bytes_; std::array dest_address_; uint16_t dest_port_; data_available_callback_type on_data_callback_; - out_of_order_callback_type on_out_of_order_callback_; + flow_packet_callback_type on_out_of_order_callback_; State state_; int mss_; flags flags_; diff --git a/include/tins/tcp_ip/stream.h b/include/tins/tcp_ip/stream.h index 6e5f0b0..87256ee 100644 --- a/include/tins/tcp_ip/stream.h +++ b/include/tins/tcp_ip/stream.h @@ -85,13 +85,13 @@ public: typedef std::function stream_callback_type; /** - * The type used for callbacks + * The type used for packet-triggered callbacks * * /sa Flow::buffering_callback */ typedef std::function out_of_order_callback_type; + const payload_type&)> stream_packet_callback_type; /** * The type used to store hardware addresses @@ -279,7 +279,7 @@ public: * \sa Flow::buffering_callback * \param callback The callback to be set */ - void client_out_of_order_callback(const out_of_order_callback_type& callback); + void client_out_of_order_callback(const stream_packet_callback_type& callback); /** * \brief Sets the callback to be executed when there's new buffered @@ -288,7 +288,7 @@ public: * \sa Flow::buffering_callback * \param callback The callback to be set */ - void server_out_of_order_callback(const out_of_order_callback_type& callback); + void server_out_of_order_callback(const stream_packet_callback_type& callback); /** * \brief Indicates that the data packets sent by the client should be @@ -352,8 +352,8 @@ private: stream_callback_type on_stream_closed_; stream_callback_type on_client_data_callback_; stream_callback_type on_server_data_callback_; - out_of_order_callback_type on_client_out_of_order_callback_; - out_of_order_callback_type on_server_out_of_order_callback_; + stream_packet_callback_type on_client_out_of_order_callback_; + stream_packet_callback_type on_server_out_of_order_callback_; hwaddress_type client_hw_addr_; hwaddress_type server_hw_addr_; timestamp_type create_time_; diff --git a/include/tins/tcp_ip/stream_follower.h b/include/tins/tcp_ip/stream_follower.h index 5ac2209..da6732f 100644 --- a/include/tins/tcp_ip/stream_follower.h +++ b/include/tins/tcp_ip/stream_follower.h @@ -79,6 +79,21 @@ public: */ typedef Stream::stream_callback_type stream_callback_type; + /** + * Enum to indicate the reason why a stream was terminated + */ + enum TerminationReason { + TIMEOUT, ///< The stream was terminated due to a timeout + BUFFERED_DATA ///< The stream was terminated because it had too much buffered data + }; + + /** + * \brief The type used for stream termination callbacks + * + * \sa StreamFollower::stream_termination_callback + */ + typedef std::function stream_termination_callback_type; + /** * Default constructor */ @@ -116,6 +131,19 @@ public: */ void new_stream_callback(const stream_callback_type& callback); + /** + * \brief Sets the stream termination callback + * + * A stream is terminated when either: + * + * * It contains too much buffered data. + * * No packets have been seen for some time interval. + * + * \param callback The callback to be executed on stream termination + * \sa StreamFollower::stream_keep_alive + */ + void stream_termination_callback(const stream_termination_callback_type& callback); + /** * \brief Sets the maximum time a stream will be followed without capturing * packets that belong to it. @@ -135,8 +163,8 @@ public: * \param server_addr The server's address * \param server_addr The server's port */ - Stream& find_stream(IPv4Address client_addr, uint16_t client_port, - IPv4Address server_addr, uint16_t server_port); + Stream& find_stream(const IPv4Address& client_addr, uint16_t client_port, + const IPv4Address& server_addr, uint16_t server_port); /** * Finds the stream identified by the provided arguments. @@ -146,14 +174,14 @@ public: * \param server_addr The server's address * \param server_addr The server's port */ - Stream& find_stream(IPv6Address client_addr, uint16_t client_port, - IPv6Address server_addr, uint16_t server_port); + Stream& find_stream(const IPv6Address& client_addr, uint16_t client_port, + const IPv6Address& server_addr, uint16_t server_port); private: typedef std::array address_type; typedef Stream::timestamp_type timestamp_type; static const size_t DEFAULT_MAX_BUFFERED_CHUNKS; - static const timestamp_type DEFAULT_CLEANUP_INTERVAL; + static const uint32_t DEFAULT_MAX_BUFFERED_BYTES; static const timestamp_type DEFAULT_KEEP_ALIVE; struct stream_id { @@ -181,7 +209,9 @@ private: streams_type streams_; stream_callback_type on_new_connection_; + stream_termination_callback_type on_stream_termination_; size_t max_buffered_chunks_; + uint32_t max_buffered_bytes_; timestamp_type last_cleanup_; timestamp_type stream_keep_alive_; bool attach_to_flows_; diff --git a/src/tcp_ip/flow.cpp b/src/tcp_ip/flow.cpp index d425e29..52f2df9 100644 --- a/src/tcp_ip/flow.cpp +++ b/src/tcp_ip/flow.cpp @@ -75,32 +75,40 @@ int seq_compare(uint32_t seq1, uint32_t seq2) { Flow::Flow(const IPv4Address& dest_address, uint16_t dest_port, uint32_t sequence_number) -: seq_number_(sequence_number), dest_port_(dest_port), state_(UNKNOWN), mss_(-1) { +: seq_number_(sequence_number), dest_port_(dest_port) { OutputMemoryStream output(dest_address_.data(), dest_address_.size()); output.write(dest_address); flags_.is_v6 = false; + initialize(); } Flow::Flow(const IPv6Address& dest_address, uint16_t dest_port, uint32_t sequence_number) -: seq_number_(sequence_number), dest_port_(dest_port), state_(UNKNOWN), mss_(-1) { +: seq_number_(sequence_number), dest_port_(dest_port) { OutputMemoryStream output(dest_address_.data(), dest_address_.size()); output.write(dest_address); flags_.is_v6 = true; + initialize(); +} + +void Flow::initialize() { + total_buffered_bytes_ = 0; + state_ = UNKNOWN; + mss_ = -1; } void Flow::data_callback(const data_available_callback_type& callback) { on_data_callback_ = callback; } -void Flow::out_of_order_callback(const out_of_order_callback_type& callback) { +void Flow::out_of_order_callback(const flow_packet_callback_type& callback) { on_out_of_order_callback_ = callback; } void Flow::process_packet(PDU& pdu) { TCP* tcp = pdu.find_pdu(); RawPDU* raw = pdu.find_pdu(); - // If we sent a packet with RST or FIN on, this flow is done + // Update the internal state first if (tcp) { update_state(*tcp); } @@ -142,6 +150,8 @@ void Flow::process_packet(PDU& pdu) { if (comparison > 0) { // Then slice it payload_type& payload = iter->second; + // First update this counter + total_buffered_bytes_ -= payload.size(); payload.erase( payload.begin(), payload.begin() + (seq_number_ - iter->first) @@ -164,10 +174,6 @@ void Flow::process_packet(PDU& pdu) { seq_number_ += iter->second.size(); iter = erase_iterator(iter); added_some = true; - // If we don't have any other payload, we're done - if (buffered_payload_.empty()) { - break; - } } } if (added_some) { @@ -182,9 +188,12 @@ void Flow::store_payload(uint32_t seq, payload_type payload) { buffered_payload_type::iterator iter = buffered_payload_.find(seq); // New segment, store it if (iter == buffered_payload_.end()) { + total_buffered_bytes_ += payload.size(); buffered_payload_.insert(make_pair(seq, move(payload))); } else if (iter->second.size() < payload.size()) { + // Increment by the diff between sizes + total_buffered_bytes_ += (payload.size() - iter->second.size()); // If we already have payload on this position but it's a shorter // chunk than the new one, replace it iter->second = move(payload); @@ -193,6 +202,7 @@ void Flow::store_payload(uint32_t seq, payload_type payload) { Flow::buffered_payload_type::iterator Flow::erase_iterator(buffered_payload_type::iterator iter) { buffered_payload_type::iterator output = iter; + total_buffered_bytes_ -= iter->second.size(); ++output; buffered_payload_.erase(iter); if (output == buffered_payload_.end()) { @@ -282,6 +292,10 @@ Flow::buffered_payload_type& Flow::buffered_payload() { return buffered_payload_; } +uint32_t Flow::total_buffered_bytes() const { + return total_buffered_bytes_; +} + Flow::payload_type& Flow::payload() { return payload_; } diff --git a/src/tcp_ip/stream.cpp b/src/tcp_ip/stream.cpp index cc902f9..1acc742 100644 --- a/src/tcp_ip/stream.cpp +++ b/src/tcp_ip/stream.cpp @@ -116,11 +116,11 @@ void Stream::server_data_callback(const stream_callback_type& callback) { on_server_data_callback_ = callback; } -void Stream::client_out_of_order_callback(const out_of_order_callback_type& callback) { +void Stream::client_out_of_order_callback(const stream_packet_callback_type& callback) { on_client_out_of_order_callback_ = callback; } -void Stream::server_out_of_order_callback(const out_of_order_callback_type& callback) { +void Stream::server_out_of_order_callback(const stream_packet_callback_type& callback) { on_server_out_of_order_callback_ = callback; } diff --git a/src/tcp_ip/stream_follower.cpp b/src/tcp_ip/stream_follower.cpp index 3f643aa..54ea341 100644 --- a/src/tcp_ip/stream_follower.cpp +++ b/src/tcp_ip/stream_follower.cpp @@ -62,10 +62,12 @@ namespace Tins { namespace TCPIP { const size_t StreamFollower::DEFAULT_MAX_BUFFERED_CHUNKS = 512; +const uint32_t StreamFollower::DEFAULT_MAX_BUFFERED_BYTES = 3 * 1024 * 1024; // 3MB const StreamFollower::timestamp_type StreamFollower::DEFAULT_KEEP_ALIVE = minutes(5); StreamFollower::StreamFollower() -: max_buffered_chunks_(DEFAULT_MAX_BUFFERED_CHUNKS), last_cleanup_(0), +: max_buffered_chunks_(DEFAULT_MAX_BUFFERED_CHUNKS), + max_buffered_bytes_(DEFAULT_MAX_BUFFERED_BYTES), last_cleanup_(0), stream_keep_alive_(DEFAULT_KEEP_ALIVE), attach_to_flows_(false) { } @@ -118,7 +120,15 @@ void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) { stream.process_packet(packet, ts); size_t total_chunks = stream.client_flow().buffered_payload().size() + stream.server_flow().buffered_payload().size(); - if (stream.is_finished() || total_chunks > max_buffered_chunks_) { + uint32_t total_buffered_bytes = stream.client_flow().total_buffered_bytes() + + stream.server_flow().total_buffered_bytes(); + bool terminate_stream = total_chunks > max_buffered_chunks_ || + total_buffered_bytes > max_buffered_bytes_; + if (stream.is_finished() || terminate_stream) { + // If we're terminating the stream, execute the termination callback + if (terminate_stream && on_stream_termination_) { + on_stream_termination_(stream, BUFFERED_DATA); + } streams_.erase(iter); } } @@ -131,15 +141,19 @@ void StreamFollower::new_stream_callback(const stream_callback_type& callback) { on_new_connection_ = callback; } -Stream& StreamFollower::find_stream(IPv4Address client_addr, uint16_t client_port, - IPv4Address server_addr, uint16_t server_port) { +void StreamFollower::stream_termination_callback(const stream_termination_callback_type& callback) { + on_stream_termination_ = callback; +} + +Stream& StreamFollower::find_stream(const IPv4Address& client_addr, uint16_t client_port, + const IPv4Address& server_addr, uint16_t server_port) { stream_id identifier(serialize(client_addr), client_port, serialize(server_addr), server_port); return find_stream(identifier); } -Stream& StreamFollower::find_stream(IPv6Address client_addr, uint16_t client_port, - IPv6Address server_addr, uint16_t server_port) { +Stream& StreamFollower::find_stream(const IPv6Address& client_addr, uint16_t client_port, + const IPv6Address& server_addr, uint16_t server_port) { stream_id identifier(serialize(client_addr), client_port, serialize(server_addr), server_port); return find_stream(identifier); @@ -195,7 +209,10 @@ void StreamFollower::cleanup_streams(const timestamp_type& now) { streams_type::iterator iter = streams_.begin(); while (iter != streams_.end()) { if (iter->second.last_seen() + stream_keep_alive_ <= now) { - // TODO: execute some callback here + // If we have a termination callback, execute it + if (on_stream_termination_) { + on_stream_termination_(iter->second, TIMEOUT); + } streams_.erase(iter++); } else { diff --git a/tests/src/tcp_ip.cpp b/tests/src/tcp_ip.cpp index cb8581a..f1fc76a 100644 --- a/tests/src/tcp_ip.cpp +++ b/tests/src/tcp_ip.cpp @@ -147,6 +147,8 @@ void FlowTest::run_test(uint32_t initial_seq, const ordering_info_type& chunks, } string flow_payload = merge_chunks(flow_payload_chunks); EXPECT_EQ(payload, string(flow_payload.begin(), flow_payload.end())); + EXPECT_EQ(0, flow.total_buffered_bytes()); + EXPECT_TRUE(flow.buffered_payload().empty()); } void FlowTest::run_test(uint32_t initial_seq, const ordering_info_type& chunks) { @@ -396,9 +398,13 @@ TEST_F(FlowTest, StreamFollower_TCPOptions) { TEST_F(FlowTest, StreamFollower_CleanupWorks) { using std::placeholders::_1; + bool timed_out = false; vector packets = three_way_handshake(29, 60, "1.2.3.4", 22, "4.3.2.1", 25); StreamFollower follower; follower.new_stream_callback(bind(&FlowTest::on_new_stream, this, _1)); + follower.stream_termination_callback([&](Stream&, StreamFollower::TerminationReason reason) { + timed_out = (reason == StreamFollower::TIMEOUT); + }); packets[2].rfind_pdu().src_addr("6.6.6.6"); auto base_time = duration_cast(system_clock::now().time_since_epoch()); Packet packet1(packets[0], base_time); @@ -414,7 +420,8 @@ TEST_F(FlowTest, StreamFollower_CleanupWorks) { EXPECT_THROW( follower.find_stream(IPv4Address("1.2.3.4"), 22, IPv4Address("4.3.2.1"), 25), stream_not_found - ); + ); + EXPECT_TRUE(timed_out); } TEST_F(FlowTest, StreamFollower_RSTClosesStream) {