1
0
mirror of https://github.com/mfontanini/libtins synced 2026-01-23 02:35:57 +01:00

Throw proper exceptions

This commit is contained in:
Matias Fontanini
2016-02-14 16:51:10 -08:00
parent 4123764a48
commit eb1c43d293
4 changed files with 101 additions and 41 deletions

View File

@@ -294,6 +294,26 @@ public:
} }
}; };
/**
* \brief Exception thrown when a required callback for an object is not set
*/
class callback_not_set : public exception_base {
public:
const char* what() const throw() {
return "Callback not set";
}
};
/**
* \brief Exception thrown when an invalid packet is provided to some function
*/
class invalid_packet : public exception_base {
public:
const char* what() const throw() {
return "Invalid packet";
}
};
namespace Crypto { namespace Crypto {
namespace WPA2 { namespace WPA2 {
/** /**

View File

@@ -84,7 +84,8 @@ public:
*/ */
enum TerminationReason { enum TerminationReason {
TIMEOUT, ///< The stream was terminated due to a timeout TIMEOUT, ///< The stream was terminated due to a timeout
BUFFERED_DATA ///< The stream was terminated because it had too much buffered data BUFFERED_DATA, ///< The stream was terminated because it had too much buffered data
SACKED_SEGMENTS ///< The stream was terminated because it had too many SACKed segments
}; };
/** /**
@@ -94,6 +95,45 @@ public:
*/ */
typedef std::function<void(Stream&, TerminationReason)> stream_termination_callback_type; typedef std::function<void(Stream&, TerminationReason)> stream_termination_callback_type;
/**
* \brief Unique identifies a stream.
*
* This struct is used to track TCP streams. It keeps track of minimum and maximum
* addresses/ports in a stream to match packets coming from any of the 2 endpoints
* into the same object.
*/
struct stream_id {
/**
* The type used to store each endpoint's address
*/
typedef std::array<uint8_t, 16> address_type;
/**
* Default constructor
*/
stream_id();
/**
* Constructs a stream_id
*
* \param client_addr Client's address
* \param client_port Port's port
* \param server_addr Server's address
* \param server_port Server's port
*/
stream_id(const address_type& client_addr, uint16_t client_port,
const address_type& server_addr, uint16_t server_port);
address_type min_address;
address_type max_address;
uint16_t min_address_port;
uint16_t max_address_port;
bool operator<(const stream_id& rhs) const;
static size_t hash(const stream_id& id);
};
/** /**
* Default constructor * Default constructor
*/ */
@@ -177,33 +217,19 @@ public:
Stream& find_stream(const IPv6Address& client_addr, uint16_t client_port, Stream& find_stream(const IPv6Address& client_addr, uint16_t client_port,
const IPv6Address& server_addr, uint16_t server_port); const IPv6Address& server_addr, uint16_t server_port);
private: private:
typedef std::array<uint8_t, 16> address_type;
typedef Stream::timestamp_type timestamp_type; typedef Stream::timestamp_type timestamp_type;
static const size_t DEFAULT_MAX_BUFFERED_CHUNKS; static const size_t DEFAULT_MAX_BUFFERED_CHUNKS;
static const size_t DEFAULT_MAX_SACKED_INTERVALS;
static const uint32_t DEFAULT_MAX_BUFFERED_BYTES; static const uint32_t DEFAULT_MAX_BUFFERED_BYTES;
static const timestamp_type DEFAULT_KEEP_ALIVE; static const timestamp_type DEFAULT_KEEP_ALIVE;
struct stream_id {
stream_id(const address_type& client_addr, uint16_t client_port,
const address_type& server_addr, uint16_t server_port);
address_type min_address;
address_type max_address;
uint16_t min_address_port;
uint16_t max_address_port;
bool operator<(const stream_id& rhs) const;
static size_t hash(const stream_id& id);
};
typedef std::map<stream_id, Stream> streams_type; typedef std::map<stream_id, Stream> streams_type;
stream_id make_stream_id(const PDU& packet); static stream_id make_stream_id(const PDU& packet);
Stream& find_stream(const stream_id& id); Stream& find_stream(const stream_id& id);
static address_type serialize(IPv4Address address); static stream_id::address_type serialize(IPv4Address address);
static address_type serialize(const IPv6Address& address); static stream_id::address_type serialize(const IPv6Address& address);
void process_packet(PDU& packet, const timestamp_type& ts); void process_packet(PDU& packet, const timestamp_type& ts);
void cleanup_streams(const timestamp_type& now); void cleanup_streams(const timestamp_type& now);

View File

@@ -208,8 +208,7 @@ const Stream::timestamp_type& Stream::last_seen() const {
Flow Stream::extract_client_flow(const PDU& packet) { Flow Stream::extract_client_flow(const PDU& packet) {
const TCP* tcp = packet.find_pdu<TCP>(); const TCP* tcp = packet.find_pdu<TCP>();
if (!tcp) { if (!tcp) {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No TCP");
} }
if (const IP* ip = packet.find_pdu<IP>()) { if (const IP* ip = packet.find_pdu<IP>()) {
return Flow(ip->dst_addr(), tcp->dport(), tcp->seq()); return Flow(ip->dst_addr(), tcp->dport(), tcp->seq());
@@ -218,16 +217,14 @@ Flow Stream::extract_client_flow(const PDU& packet) {
return Flow(ip->dst_addr(), tcp->dport(), tcp->seq()); return Flow(ip->dst_addr(), tcp->dport(), tcp->seq());
} }
else { else {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No valid layer 3");
} }
} }
Flow Stream::extract_server_flow(const PDU& packet) { Flow Stream::extract_server_flow(const PDU& packet) {
const TCP* tcp = packet.find_pdu<TCP>(); const TCP* tcp = packet.find_pdu<TCP>();
if (!tcp) { if (!tcp) {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No TCP");
} }
if (const IP* ip = packet.find_pdu<IP>()) { if (const IP* ip = packet.find_pdu<IP>()) {
return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq()); return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq());
@@ -236,8 +233,7 @@ Flow Stream::extract_server_flow(const PDU& packet) {
return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq()); return Flow(ip->src_addr(), tcp->sport(), tcp->ack_seq());
} }
else { else {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No valid layer 3");
} }
} }

View File

@@ -62,6 +62,7 @@ namespace Tins {
namespace TCPIP { namespace TCPIP {
const size_t StreamFollower::DEFAULT_MAX_BUFFERED_CHUNKS = 512; const size_t StreamFollower::DEFAULT_MAX_BUFFERED_CHUNKS = 512;
const size_t StreamFollower::DEFAULT_MAX_SACKED_INTERVALS = 1024;
const uint32_t StreamFollower::DEFAULT_MAX_BUFFERED_BYTES = 3 * 1024 * 1024; // 3MB const uint32_t StreamFollower::DEFAULT_MAX_BUFFERED_BYTES = 3 * 1024 * 1024; // 3MB
const StreamFollower::timestamp_type StreamFollower::DEFAULT_KEEP_ALIVE = minutes(5); const StreamFollower::timestamp_type StreamFollower::DEFAULT_KEEP_ALIVE = minutes(5);
@@ -83,24 +84,26 @@ void StreamFollower::process_packet(Packet& packet) {
} }
void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) { void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) {
const TCP* tcp = packet.find_pdu<TCP>();
if (!tcp) {
return;
}
stream_id identifier = make_stream_id(packet); stream_id identifier = make_stream_id(packet);
streams_type::iterator iter = streams_.find(identifier); streams_type::iterator iter = streams_.find(identifier);
bool process = true; bool process = true;
if (iter == streams_.end()) { if (iter == streams_.end()) {
const TCP& tcp = packet.rfind_pdu<TCP>();
// Start tracking if they're either SYNs or they contain data (attach // Start tracking if they're either SYNs or they contain data (attach
// to an already running flow). // to an already running flow).
if (tcp.flags() == TCP::SYN || (attach_to_flows_ && tcp.find_pdu<RawPDU>() != 0)) { if (tcp->flags() == TCP::SYN || (attach_to_flows_ && tcp->find_pdu<RawPDU>() != 0)) {
iter = streams_.insert(make_pair(identifier, Stream(packet, ts))).first; iter = streams_.insert(make_pair(identifier, Stream(packet, ts))).first;
iter->second.setup_flows_callbacks(); iter->second.setup_flows_callbacks();
if (on_new_connection_) { if (on_new_connection_) {
on_new_connection_(iter->second); on_new_connection_(iter->second);
} }
else { else {
// TODO: use proper exception throw callback_not_set();
throw runtime_error("No new connection callback set");
} }
if (tcp.flags() == TCP::SYN) { if (tcp->flags() == TCP::SYN) {
process = false; process = false;
} }
else { else {
@@ -118,16 +121,27 @@ void StreamFollower::process_packet(PDU& packet, const timestamp_type& ts) {
if (process) { if (process) {
Stream& stream = iter->second; Stream& stream = iter->second;
stream.process_packet(packet, ts); stream.process_packet(packet, ts);
// Check for different potential termination
size_t total_chunks = stream.client_flow().buffered_payload().size() + size_t total_chunks = stream.client_flow().buffered_payload().size() +
stream.server_flow().buffered_payload().size(); stream.server_flow().buffered_payload().size();
uint32_t total_buffered_bytes = stream.client_flow().total_buffered_bytes() + uint32_t total_buffered_bytes = stream.client_flow().total_buffered_bytes() +
stream.server_flow().total_buffered_bytes(); stream.server_flow().total_buffered_bytes();
bool terminate_stream = total_chunks > max_buffered_chunks_ || bool terminate_stream = total_chunks > max_buffered_chunks_ ||
total_buffered_bytes > max_buffered_bytes_; total_buffered_bytes > max_buffered_bytes_;
TerminationReason reason = BUFFERED_DATA;
#ifdef HAVE_ACK_TRACKER
if (!terminate_stream) {
uint32_t count = 0;
count += stream.client_flow().ack_tracker().acked_intervals().iterative_size();
count += stream.server_flow().ack_tracker().acked_intervals().iterative_size();
terminate_stream = count > DEFAULT_MAX_SACKED_INTERVALS;
reason = SACKED_SEGMENTS;
}
#endif // HAVE_ACK_TRACKER
if (stream.is_finished() || terminate_stream) { if (stream.is_finished() || terminate_stream) {
// If we're terminating the stream, execute the termination callback // If we're terminating the stream, execute the termination callback
if (terminate_stream && on_stream_termination_) { if (terminate_stream && on_stream_termination_) {
on_stream_termination_(stream, BUFFERED_DATA); on_stream_termination_(stream, reason);
} }
streams_.erase(iter); streams_.erase(iter);
} }
@@ -162,8 +176,7 @@ Stream& StreamFollower::find_stream(const IPv6Address& client_addr, uint16_t cli
StreamFollower::stream_id StreamFollower::make_stream_id(const PDU& packet) { StreamFollower::stream_id StreamFollower::make_stream_id(const PDU& packet) {
const TCP* tcp = packet.find_pdu<TCP>(); const TCP* tcp = packet.find_pdu<TCP>();
if (!tcp) { if (!tcp) {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No TCP");
} }
if (const IP* ip = packet.find_pdu<IP>()) { if (const IP* ip = packet.find_pdu<IP>()) {
return stream_id(serialize(ip->src_addr()), tcp->sport(), return stream_id(serialize(ip->src_addr()), tcp->sport(),
@@ -174,8 +187,7 @@ StreamFollower::stream_id StreamFollower::make_stream_id(const PDU& packet) {
serialize(ip->dst_addr()), tcp->dport()); serialize(ip->dst_addr()), tcp->dport());
} }
else { else {
// TODO: define proper exception throw invalid_packet();
throw runtime_error("No layer 3");
} }
} }
@@ -189,16 +201,16 @@ Stream& StreamFollower::find_stream(const stream_id& id) {
} }
} }
StreamFollower::address_type StreamFollower::serialize(IPv4Address address) { StreamFollower::stream_id::address_type StreamFollower::serialize(IPv4Address address) {
address_type addr; stream_id::address_type addr;
OutputMemoryStream output(addr.data(), addr.size()); OutputMemoryStream output(addr.data(), addr.size());
addr.fill(0); addr.fill(0);
output.write(address); output.write(address);
return addr; return addr;
} }
StreamFollower::address_type StreamFollower::serialize(const IPv6Address& address) { StreamFollower::stream_id::address_type StreamFollower::serialize(const IPv6Address& address) {
address_type addr; stream_id::address_type addr;
OutputMemoryStream output(addr.data(), addr.size()); OutputMemoryStream output(addr.data(), addr.size());
addr.fill(0); addr.fill(0);
output.write(address); output.write(address);
@@ -224,6 +236,12 @@ void StreamFollower::cleanup_streams(const timestamp_type& now) {
// stream_id // stream_id
StreamFollower::stream_id::stream_id()
: min_address_port(0), max_address_port(0) {
min_address.fill(0);
max_address.fill(0);
}
StreamFollower::stream_id::stream_id(const address_type& client_addr, StreamFollower::stream_id::stream_id(const address_type& client_addr,
uint16_t client_port, uint16_t client_port,
const address_type& server_addr, const address_type& server_addr,