diff --git a/src/tcp_ip/ack_tracker.cpp b/src/tcp_ip/ack_tracker.cpp index 9991e7c..9d3e326 100644 --- a/src/tcp_ip/ack_tracker.cpp +++ b/src/tcp_ip/ack_tracker.cpp @@ -46,6 +46,15 @@ using Tins::Internals::seq_compare; namespace Tins { namespace TCPIP { +uint32_t interval_start(const AckedRange::interval_type& interval) { + if (interval.bounds() == interval_bounds::left_open()) { + return interval.lower() + 1; + } + else { + return interval.lower(); + } +} + uint32_t interval_end(const AckedRange::interval_type& interval) { if (interval.bounds() == interval_bounds::right_open()) { return interval.upper() - 1; @@ -123,10 +132,21 @@ void AckTracker::process_sack(const vector& sack) { // Left edge must be lower than right edge if (seq_compare(sack[i - 1], sack[i]) < 0) { AckedRange range(sack[i - 1], sack[i] - 1); - // If this range starts after our current ack number - if (seq_compare(range.first(), ack_number_) > 0) { + // If this range ends after our current ack number + if (seq_compare(range.last(), ack_number_) > 0) { while (range.has_next()) { - acked_intervals_.insert(range.next()); + AckedRange::interval_type next = range.next(); + uint32_t start = interval_start(next); + if (seq_compare(start, ack_number_) <= 0) { + // If this interval starts before or at our ACK number + // then we need to update our ACK number to the end of + // this interval + ack_number_ = interval_end(next); + } + else { + // Otherwise, push the interval into the ACK set + acked_intervals_.insert(next); + } } } } diff --git a/tests/src/tcp_ip.cpp b/tests/src/tcp_ip.cpp index be53c20..4a4b326 100644 --- a/tests/src/tcp_ip.cpp +++ b/tests/src/tcp_ip.cpp @@ -663,15 +663,19 @@ TEST_F(AckTrackerTest, AckingTcp_Sack2) { EXPECT_TRUE(tracker.is_segment_acked(maximum - 2, 1)); EXPECT_TRUE(tracker.is_segment_acked(2, 3)); EXPECT_FALSE(tracker.is_segment_acked(maximum - 10, 10)); + EXPECT_EQ(maximum - 10, tracker.ack_number()); tracker.process_packet(make_tcp_ack(maximum - 2)); EXPECT_EQ(1U + 10U, tracker.acked_intervals().size()); + EXPECT_EQ(maximum - 2, tracker.ack_number()); tracker.process_packet(make_tcp_ack(5)); EXPECT_EQ(4U, tracker.acked_intervals().size()); + EXPECT_EQ(5, tracker.ack_number()); tracker.process_packet(make_tcp_ack(15)); EXPECT_EQ(0U, tracker.acked_intervals().size()); + EXPECT_EQ(15, tracker.ack_number()); } TEST_F(AckTrackerTest, AckingTcp_Sack3) { @@ -682,9 +686,27 @@ TEST_F(AckTrackerTest, AckingTcp_Sack3) { make_pair(maximum - 3, 5) )); EXPECT_EQ(9U, tracker.acked_intervals().size()); + EXPECT_EQ(maximum - 10, tracker.ack_number()); tracker.process_packet(make_tcp_ack(maximum)); EXPECT_EQ(5U, tracker.acked_intervals().size()); + EXPECT_EQ(maximum, tracker.ack_number()); +} + +TEST_F(AckTrackerTest, AckingTcp_SackOutOfOrder1) { + AckTracker tracker(0, true); + tracker.process_packet(make_tcp_ack(10)); + tracker.process_packet(make_tcp_ack(0, make_pair(9, 12))); + EXPECT_EQ(0, tracker.acked_intervals().size()); + EXPECT_EQ(11, tracker.ack_number()); +} + +TEST_F(AckTrackerTest, AckingTcp_SackOutOfOrder2) { + AckTracker tracker(0, true); + tracker.process_packet(make_tcp_ack(10)); + tracker.process_packet(make_tcp_ack(0, make_pair(10, 12))); + EXPECT_EQ(0, tracker.acked_intervals().size()); + EXPECT_EQ(11, tracker.ack_number()); } TEST_F(FlowTest, AckNumbersAreCorrect) {