diff --git a/include/tins/dns.h b/include/tins/dns.h index 4598e15..7b52553 100644 --- a/include/tins/dns.h +++ b/include/tins/dns.h @@ -236,11 +236,16 @@ public: * \param rclass The class of this record. * \param ttl The time-to-live of this record. */ - Resource(const std::string& dname, const std::string& data, - uint16_t type, uint16_t rclass, uint32_t ttl) - : dname_(dname), data_(data), type_(type), qclass_(rclass), ttl_(ttl) {} + Resource(const std::string& dname, + const std::string& data, + uint16_t type, + uint16_t rclass, + uint32_t ttl, + uint16_t preference = 0) + : dname_(dname), data_(data), type_(type), qclass_(rclass), + ttl_(ttl), preference_(preference) {} - Resource() : type_(), qclass_(), ttl_() {} + Resource() : type_(), qclass_(), ttl_(), preference_() {} /** * \brief Getter for the domain name field. @@ -248,27 +253,46 @@ public: * This returns the domain name for which this record * provides an answer. */ - const std::string& dname() const { return dname_; } + const std::string& dname() const { + return dname_; + } /** * Getter for the data field. */ - const std::string& data() const { return data_; } + const std::string& data() const { + return data_; + } /** * Getter for the query type field. */ - uint16_t type() const { return type_; } + uint16_t type() const { + return type_; + } /** * Getter for the query class field. */ - uint16_t query_class() const { return qclass_; } + uint16_t query_class() const { + return qclass_; + } /** - * Getter for the type field. + * Getter for the time-to-live field. */ - uint32_t ttl() const { return ttl_; } + uint32_t ttl() const { + return ttl_; + } + + /** + * \brief Getter for the preferece field. + * + * This field is only valid for MX resources. + */ + uint16_t preference() const { + return preference_; + } /** * Setter for the domain name field. @@ -313,10 +337,20 @@ public: void ttl(uint32_t data) { ttl_ = data; } + + /** + * \brief Setter for the preference field. + * + * This field is only valid for MX resources. + */ + void preference(uint16_t data) { + preference_ = data; + } private: std::string dname_, data_; uint16_t type_, qclass_; uint32_t ttl_; + uint16_t preference_; }; typedef std::list queries_type; @@ -711,11 +745,16 @@ private: typedef std::vector > sections_type; uint32_t compose_name(const uint8_t* ptr, char* out_ptr) const; - void convert_records(const uint8_t* ptr, const uint8_t* end, resources_type& res) const; + void convert_records(const uint8_t* ptr, + const uint8_t* end, + resources_type& res) const; void skip_to_section_end(Memory::InputMemoryStream& stream, const uint32_t num_records) const; void skip_to_dname_end(Memory::InputMemoryStream& stream) const; - void update_records(uint32_t& section_start, uint32_t num_records, uint32_t threshold, uint32_t offset); + void update_records(uint32_t& section_start, + uint32_t num_records, + uint32_t threshold, + uint32_t offset); uint8_t* update_dname(uint8_t* ptr, uint32_t threshold, uint32_t offset); static void inline_convert_v4(uint32_t value, char* output); static bool contains_dname(uint16_t type); diff --git a/src/dns.cpp b/src/dns.cpp index 8b15ffe..95e539a 100644 --- a/src/dns.cpp +++ b/src/dns.cpp @@ -203,7 +203,8 @@ void DNS::add_record(const Resource& resource, const sections_type& sections) { // will end up being inconsistent. IPv4Address v4_addr; IPv6Address v6_addr; - string buffer = encode_domain_name(resource.dname()), encoded_data; + string buffer = encode_domain_name(resource.dname()), + encoded_data; // By default the data size is the length of the data field. size_t data_size = resource.data().size(); if (resource.type() == A) { @@ -220,13 +221,17 @@ void DNS::add_record(const Resource& resource, const sections_type& sections) { } size_t offset = buffer.size() + sizeof(uint16_t) * 3 + sizeof(uint32_t) + data_size, threshold = sections.empty() ? records_data_.size() :* sections.front().first; - // Skip the preference field + // Take into account the MX preference field if (resource.type() == MX) { offset += sizeof(uint16_t); } for (size_t i = 0; i < sections.size(); ++i) { - update_records(*sections[i].first, sections[i].second, - static_cast(threshold), static_cast(offset)); + update_records( + *sections[i].first, + sections[i].second, + static_cast(threshold), + static_cast(offset) + ); } records_data_.insert( @@ -241,7 +246,7 @@ void DNS::add_record(const Resource& resource, const sections_type& sections) { stream.write_be(resource.ttl()); stream.write_be(data_size + (resource.type() == MX ? 2 : 0)); if (resource.type() == MX) { - stream.skip(sizeof(uint16_t)); + stream.write_be(resource.preference()); } if (resource.type() == A) { stream.write(v4_addr); @@ -380,15 +385,15 @@ void DNS::convert_records(const uint8_t* ptr, // Retrieve the record's domain name. stream.skip(compose_name(stream.pointer(), dname)); // Retrieve the following fields. - uint16_t type, qclass, data_size; + uint16_t type, qclass, data_size, preference = 0; uint32_t ttl; type = stream.read_be(); qclass = stream.read_be(); ttl = stream.read_be(); data_size = stream.read_be(); - // Skip the preference field if it's MX + // Read the preference field if it's MX if (type == MX) { - stream.skip(sizeof(uint16_t)); + preference = stream.read_be(); data_size -= sizeof(uint16_t); } if (!stream.can_read(data_size)) { @@ -431,7 +436,8 @@ void DNS::convert_records(const uint8_t* ptr, (used_small_buffer) ? small_addr_buf : addr, type, qclass, - ttl + ttl, + preference ) ); } diff --git a/tests/src/dns.cpp b/tests/src/dns.cpp index 8b3624e..22c0527 100644 --- a/tests/src/dns.cpp +++ b/tests/src/dns.cpp @@ -118,6 +118,7 @@ TEST_F(DNSTest, ConstructorFromBuffer2) { } DNS::resources_type resources = dns.answers(); + size_t resource_index = 0; for(DNS::resources_type::const_iterator it = resources.begin(); it != resources.end(); ++it) { EXPECT_EQ("google.com", it->dname()); EXPECT_EQ(DNS::MX, it->type()); @@ -130,6 +131,13 @@ TEST_F(DNSTest, ConstructorFromBuffer2) { it->data() == "alt5.aspmx.l.google.com" || it->data() == "aspmx.l.google.com" ); + if (resource_index == 0) { + EXPECT_EQ(50, it->preference()); + } + else if (resource_index == 1) { + EXPECT_EQ(40, it->preference()); + } + resource_index++; } // Add some stuff and see if something gets broken if(i == 0) { @@ -450,3 +458,18 @@ TEST_F(DNSTest, ItAintGonnaCorrupt) { EXPECT_EQ(it->query_class(), DNS::IN); } } + +TEST_F(DNSTest, MXPreferenceField) { + DNS dns1; + dns1.add_answer( + DNS::Resource("example.com", "mail.example.com", DNS::MX, DNS::IN, 0x762, 42) + ); + DNS::serialization_type buffer = dns1.serialize(); + DNS dns2(&buffer[0], buffer.size()); + DNS::resources_type answers = dns1.answers(); + ASSERT_EQ(1, answers.size()); + + const DNS::Resource& resource = *answers.begin(); + EXPECT_EQ(42, resource.preference()); + EXPECT_EQ("example.com", resource.dname()); +}