diff --git a/include/tins/exceptions.h b/include/tins/exceptions.h index 75d449f..05e0d38 100644 --- a/include/tins/exceptions.h +++ b/include/tins/exceptions.h @@ -294,6 +294,26 @@ public: } }; +/** + * \brief Exception thrown when a required callback for an object is not set + */ +class callback_not_set : public exception_base { +public: + const char* what() const throw() { + return "Callback not set"; + } +}; + +/** + * \brief Exception thrown when an invalid packet is provided to some function + */ +class invalid_packet : public exception_base { +public: + const char* what() const throw() { + return "Invalid packet"; + } +}; + namespace Crypto { namespace WPA2 { /** diff --git a/include/tins/tcp_ip/stream_follower.h b/include/tins/tcp_ip/stream_follower.h index da6732f..a648664 100644 --- a/include/tins/tcp_ip/stream_follower.h +++ b/include/tins/tcp_ip/stream_follower.h @@ -84,7 +84,8 @@ public: */ enum TerminationReason { TIMEOUT, ///< The stream was terminated due to a timeout - BUFFERED_DATA ///< The stream was terminated because it had too much buffered data + BUFFERED_DATA, ///< The stream was terminated because it had too much buffered data + SACKED_SEGMENTS ///< The stream was terminated because it had too many SACKed segments }; /** @@ -94,6 +95,45 @@ public: */ typedef std::function stream_termination_callback_type; + /** + * \brief Unique identifies a stream. + * + * This struct is used to track TCP streams. It keeps track of minimum and maximum + * addresses/ports in a stream to match packets coming from any of the 2 endpoints + * into the same object. + */ + struct stream_id { + /** + * The type used to store each endpoint's address + */ + typedef std::array address_type; + + /** + * Default constructor + */ + stream_id(); + + /** + * Constructs a stream_id + * + * \param client_addr Client's address + * \param client_port Port's port + * \param server_addr Server's address + * \param server_port Server's port + */ + stream_id(const address_type& client_addr, uint16_t client_port, + const address_type& server_addr, uint16_t server_port); + + address_type min_address; + address_type max_address; + uint16_t min_address_port; + uint16_t max_address_port; + + bool operator<(const stream_id& rhs) const; + + static size_t hash(const stream_id& id); + }; + /** * Default constructor */ @@ -177,33 +217,19 @@ public: Stream& find_stream(const IPv6Address& client_addr, uint16_t client_port, const IPv6Address& server_addr, uint16_t server_port); private: - typedef std::array address_type; typedef Stream::timestamp_type timestamp_type; static const size_t DEFAULT_MAX_BUFFERED_CHUNKS; + static const size_t DEFAULT_MAX_SACKED_INTERVALS; static const uint32_t DEFAULT_MAX_BUFFERED_BYTES; static const timestamp_type DEFAULT_KEEP_ALIVE; - struct stream_id { - stream_id(const address_type& client_addr, uint16_t client_port, - const address_type& server_addr, uint16_t server_port); - - address_type min_address; - address_type max_address; - uint16_t min_address_port; - uint16_t max_address_port; - - bool operator<(const stream_id& rhs) const; - - static size_t hash(const stream_id& id); - }; - typedef std::map streams_type; - stream_id make_stream_id(const PDU& packet); + static stream_id make_stream_id(const PDU& packet); Stream& find_stream(const stream_id& id); - static address_type serialize(IPv4Address address); - static address_type serialize(const IPv6Address& address); + static stream_id::address_type serialize(IPv4Address address); + static stream_id::address_type serialize(const IPv6Address& address); void process_packet(PDU& packet, const timestamp_type& ts); void cleanup_streams(const timestamp_type& now); diff --git a/src/tcp_ip/stream.cpp b/src/tcp_ip/stream.cpp index a2f1785..0e5c618 100644 --- a/src/tcp_ip/stream.cpp +++ b/src/tcp_ip/stream.cpp @@ -208,8 +208,7 @@ const Stream::timestamp_type& Stream::last_seen() const { Flow Stream::extract_client_flow(const PDU& packet) { const TCP* tcp = packet.find_pdu(); if (!tcp) { - // TODO: define proper exception - throw runtime_error("No TCP"); + throw invalid_packet(); } if (const IP* ip = packet.find_pdu()) { return Flow(ip->dst_addr(), tcp->dport(), tcp->seq()); @@ -218,16 +217,14 @@ Flow Stream::extract_client_flow(const PDU& packet) { return Flow(ip->dst_addr(), tcp->dport(), tcp->seq()); } else { - // TODO: define proper exception - throw runtime_error("No valid layer 3"); + throw invalid_packet(); } } Flow Stream::extract_server_flow(const PDU& packet) { const TCP* tcp = packet.find_pdu(); if (!tcp) { - // TODO: define proper exception - throw runtime_error("No TCP"); + throw invalid_packet(); } if (const IP* ip = packet.find_pdu()) { return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq()); @@ -236,8 +233,7 @@ Flow Stream::extract_server_flow(const PDU& packet) { return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq()); } else { - // TODO: define proper exception - throw runtime_error("No valid layer 3"); + throw invalid_packet(); } } diff --git a/src/tcp_ip/stream_follower.cpp b/src/tcp_ip/stream_follower.cpp index 54ea341..d0a2d38 100644 --- a/src/tcp_ip/stream_follower.cpp +++ b/src/tcp_ip/stream_follower.cpp @@ -62,6 +62,7 @@ namespace Tins { namespace TCPIP { const size_t StreamFollower::DEFAULT_MAX_BUFFERED_CHUNKS = 512; +const size_t StreamFollower::DEFAULT_MAX_SACKED_INTERVALS = 1024; const uint32_t StreamFollower::DEFAULT_MAX_BUFFERED_BYTES = 3 * 1024 * 1024; // 3MB const StreamFollower::timestamp_type StreamFollower::DEFAULT_KEEP_ALIVE = minutes(5); @@ -83,24 +84,26 @@ void StreamFollower::process_packet(Packet& packet) { } void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) { + const TCP* tcp = packet.find_pdu(); + if (!tcp) { + return; + } stream_id identifier = make_stream_id(packet); streams_type::iterator iter = streams_.find(identifier); bool process = true; if (iter == streams_.end()) { - const TCP& tcp = packet.rfind_pdu(); // Start tracking if they're either SYNs or they contain data (attach // to an already running flow). - if (tcp.flags() == TCP::SYN || (attach_to_flows_ && tcp.find_pdu() != 0)) { + if (tcp->flags() == TCP::SYN || (attach_to_flows_ && tcp->find_pdu() != 0)) { iter = streams_.insert(make_pair(identifier, Stream(packet, ts))).first; iter->second.setup_flows_callbacks(); if (on_new_connection_) { on_new_connection_(iter->second); } else { - // TODO: use proper exception - throw runtime_error("No new connection callback set"); + throw callback_not_set(); } - if (tcp.flags() == TCP::SYN) { + if (tcp->flags() == TCP::SYN) { process = false; } else { @@ -118,16 +121,27 @@ void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) { if (process) { Stream& stream = iter->second; stream.process_packet(packet, ts); + // Check for different potential termination size_t total_chunks = stream.client_flow().buffered_payload().size() + stream.server_flow().buffered_payload().size(); uint32_t total_buffered_bytes = stream.client_flow().total_buffered_bytes() + stream.server_flow().total_buffered_bytes(); bool terminate_stream = total_chunks > max_buffered_chunks_ || total_buffered_bytes > max_buffered_bytes_; + TerminationReason reason = BUFFERED_DATA; + #ifdef HAVE_ACK_TRACKER + if (!terminate_stream) { + uint32_t count = 0; + count += stream.client_flow().ack_tracker().acked_intervals().iterative_size(); + count += stream.server_flow().ack_tracker().acked_intervals().iterative_size(); + terminate_stream = count > DEFAULT_MAX_SACKED_INTERVALS; + reason = SACKED_SEGMENTS; + } + #endif // HAVE_ACK_TRACKER if (stream.is_finished() || terminate_stream) { // If we're terminating the stream, execute the termination callback if (terminate_stream && on_stream_termination_) { - on_stream_termination_(stream, BUFFERED_DATA); + on_stream_termination_(stream, reason); } streams_.erase(iter); } @@ -162,8 +176,7 @@ Stream& StreamFollower::find_stream(const IPv6Address& client_addr, uint16_t cli StreamFollower::stream_id StreamFollower::make_stream_id(const PDU& packet) { const TCP* tcp = packet.find_pdu(); if (!tcp) { - // TODO: define proper exception - throw runtime_error("No TCP"); + throw invalid_packet(); } if (const IP* ip = packet.find_pdu()) { return stream_id(serialize(ip->src_addr()), tcp->sport(), @@ -174,8 +187,7 @@ StreamFollower::stream_id StreamFollower::make_stream_id(const PDU& packet) { serialize(ip->dst_addr()), tcp->dport()); } else { - // TODO: define proper exception - throw runtime_error("No layer 3"); + throw invalid_packet(); } } @@ -189,16 +201,16 @@ Stream& StreamFollower::find_stream(const stream_id& id) { } } -StreamFollower::address_type StreamFollower::serialize(IPv4Address address) { - address_type addr; +StreamFollower::stream_id::address_type StreamFollower::serialize(IPv4Address address) { + stream_id::address_type addr; OutputMemoryStream output(addr.data(), addr.size()); addr.fill(0); output.write(address); return addr; } -StreamFollower::address_type StreamFollower::serialize(const IPv6Address& address) { - address_type addr; +StreamFollower::stream_id::address_type StreamFollower::serialize(const IPv6Address& address) { + stream_id::address_type addr; OutputMemoryStream output(addr.data(), addr.size()); addr.fill(0); output.write(address); @@ -224,6 +236,12 @@ void StreamFollower::cleanup_streams(const timestamp_type& now) { // stream_id +StreamFollower::stream_id::stream_id() +: min_address_port(0), max_address_port(0) { + min_address.fill(0); + max_address.fill(0); +} + StreamFollower::stream_id::stream_id(const address_type& client_addr, uint16_t client_port, const address_type& server_addr,