From d6ba6a919932cb0d4648b4a403f2cd385dabfd39 Mon Sep 17 00:00:00 2001 From: inga-lovinde <52715130+inga-lovinde@users.noreply.github.com> Date: Sat, 5 Dec 2020 00:29:15 +0100 Subject: [PATCH] updated readme; implemented MD5 computation using SIMD --- Cargo.toml | 7 +- README.md | 87 ++++++++++++++++- src/anagram_logger.rs | 6 +- src/dictionary_builder.rs | 37 +++++++- src/hash_computer.rs | 184 +++++++++++++++++++++++++++++++++++- src/lib.rs | 9 ++ src/main.rs | 25 +++-- src/permutations_cache.rs | 1 + src/vector_alphabet.rs | 2 +- tests/hash_computer_test.rs | 57 +++++++++++ 10 files changed, 393 insertions(+), 22 deletions(-) create mode 100644 src/lib.rs create mode 100644 src/permutations_cache.rs create mode 100644 tests/hash_computer_test.rs diff --git a/Cargo.toml b/Cargo.toml index 18b4e8c..cc59c88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "hello_cargo" +name = "trustpilot_challenge_rust" version = "0.1.0" authors = ["inga-lovinde <52715130+inga-lovinde@users.noreply.github.com>"] edition = "2018" @@ -7,4 +7,7 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -packed_simd = { version = "0.3.4", package = "packed_simd_2" } +crunchy = "0.2.2" +packed_simd = { version = "0.3.4", package = "packed_simd_2", features = ["into_bits"] } +permutohedron = "0.2.4" +rayon = "1.5.0" diff --git a/README.md b/README.md index a8c39a9..ce480e8 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# TrustPilotChallengeRust +# TrustPilotChallengeRust TrustPilot had this challenge several years ago (http://followthewhiterabbit.trustpilot.com/) @@ -26,7 +26,12 @@ and does not yet do actual MD5 calculation. ## Algorithm description -Notably this solution does not involve string concatenation; strings are only concatenated for debugging purposes. +Notably this solution does not involve string concatenation; +strings are only concatenated for debugging purposes. 
+It also computes eight MD5 hashes at a time *per thread* +(that is, 128 MD5 hashes at once on a modern 8-core CPU), +with some additional optimizations which further shave off +several percent from MD5 computation time. We could split the problem into three parts: finding all anagrams (up to words reordering and replacing some of the words with their single-word anagrams), @@ -70,6 +75,84 @@ could be further optimized, resulting in the following algorithm: * W might be one element of a target subset, and the remaining elements could be found by solving the task 2 for N-1, P-W and position of W in the list of vectors. +### Finding all anagrams, pt. 2 + +In the previous step, we just found all unique tuples of vectors with vectors ordered by decreasing norm +such that they give the required vector. +Now we need to convert these back to phrases. + +If for every vector there was only one word which produces that vector, +and if all the vectors in a tuple were different, +we could just look at all their permutations and get n! solutions from a single tuple. + +But a tuple can contain several copies of one vector, +and there could be several different words corresponding to one vector. +Computing all possible permutations would result in duplicate solutions +and too much unnecessary work. + +So we could: + +1. Substitute all possible word values for every vector, getting several (ordered) word solutions; +2. Apply all possible permutations to them such that, if vectors k and k+1 were the same in the vector solution, + word k should go before word k+1 in the word solution + (because the solution where word k goes after word k+1 is already obtained by a different substitution in step 1). + +Every string shorter than 32 bytes could be represented as a single u8x32 AVX2 register +(with the remaining bytes filled with zeroes). + +Concatenating strings could be as simple as XORing the vectors, shifted appropriately. 
+ +For example, to create the `"a b "` string we would need to compute `"a " xor "␀␀b "`, +which is done in a single cycle on a modern CPU, provided that we have both vectors ready. +This is as opposed to concatenating strings, which would require allocating a new string on the heap +and copying the data. + +So we could just store all of the original words as such vectors for all possible offsets +(along with trailing spaces), and when we need to compute a phrase consisting of the word x and the word y, +just do something along the lines of `get_register(x, 0) xor get_register(y, x.length + 1)` + +### Computing hashes + +MD5 works on input messages in 64-byte blocks; for short strings (shorter than 55 bytes) +it only uses a single block: a 0x80 byte is appended to the message, then it is padded to 56 bytes with zeroes, +and then the total length of the string in bits is appended as a 64-bit number. + +So short phrases (shorter than 31 bytes) could be represented with two AVX2 registers: +one containing the phrase itself with the trailing 0x80, and another containing 24 zeroes +and the 64-bit length of the phrase in bits (which is the number of non-space bytes +plus the number of words minus one, times 8). + +For its internal state, MD5 has four 32-bit variables (u32). +This means that with AVX2, we can use the same operations on 256-bit registers +(u32x8) and compute eight hashes at the same time in a single thread. + +MD5 breaks input chunks into 16 u32 words (and for short phrases words 8-13 and 15 are always zero), +so our algorithm could receive 8x256-bit values and the phrase length, +rearrange these into 9 256-bit values (8 obtained by transposing the original 8 as an 8x8 matrix of u32, +and the ninth being 8 copies of the phrase length in bits), +and then implement the MD5 algorithm using these 9 values as input words 0..7, 14 +(substituting 0 as input words 8..13 and 15). + +That way, MD5 performance would be increased 8x compared to the ordinary library function +which does not use SIMD. 
+ +As a minor additional optimization, we could only compute the first u32 part of the MD5 hash +(because we don't need to compute entire hashes for all possible anagrams, +we only need to find anagrams which match the requested hashes). +That way, we'll save some unneeded steps in MD5 computation, +and we also won't have to convert hashes back to separate variables: +we could just compare u32x8 holding the first parts of hashes for eight different anagrams +with u32x8 holding eight copies of the first part of the requested hash. +That way, we'll only have one comparison instead of eight, +at the cost of rare false positives which occur on average with 1/2^29 probability +(1/2^32 chance that a random u32 matches the requested u32, for each of the eight anagrams). +If there is such a semi-match (that is, one of the eight anagrams produces a hash +with the first 32 bits matching the first 32 bits of the requested hash), we could just +compute MD5 for each of the eight anagrams in the ordinary way and +compare the whole resulting hashes with the requested ones; +as this is extremely rare (once every 2^29 calls to the SIMD MD5 function), +it will not severely affect performance. 
+ ## How to run How to run to solve the original task for three-word anagrams: diff --git a/src/anagram_logger.rs b/src/anagram_logger.rs index d1d32ca..ee624e2 100644 --- a/src/anagram_logger.rs +++ b/src/anagram_logger.rs @@ -5,11 +5,11 @@ pub fn get_anagram_view(anagram: Vec, dictionary: &dictionary_builder::Di .map(|&index| { let word_options = &dictionary.words[index]; if word_options.len() == 1 { - word_options[0].clone() + word_options[0].word.clone() } else { - format!("[{}]", word_options.join(",")) + format!("[{}]", word_options.iter().map(|word_info| word_info.word.clone()).collect::>().join(",")) } }) .collect::>() .join(" ") -} \ No newline at end of file +} diff --git a/src/dictionary_builder.rs b/src/dictionary_builder.rs index 993f3e0..ec403f7 100644 --- a/src/dictionary_builder.rs +++ b/src/dictionary_builder.rs @@ -1,13 +1,44 @@ use std::collections::HashMap; +use packed_simd; use crate::vector_alphabet; +pub struct WordInfo { + simd_words: [packed_simd::u8x32; 32], + pub length: usize, + pub word: String, +} + +impl WordInfo { + fn new(word: String) -> WordInfo { + let mut byte_array: [u8; 64] = [0; 64]; + let bytes = word.as_bytes(); + let length = bytes.len(); + byte_array[32 + length] = b' '; + for i in 0..length { + byte_array[32 + i] = bytes[i]; + } + + let simd_word_zero: packed_simd::u8x32 = packed_simd::u8x32::from_slice_unaligned(&[0; 32]); + let mut simd_words: [packed_simd::u8x32; 32] = [simd_word_zero; 32]; + for i in 0..31 { + simd_words[i] = packed_simd::u8x32::from_slice_unaligned(&byte_array[32-i..64-i]); + } + + WordInfo { + simd_words, + length, + word, + } + } +} + pub struct Dictionary { pub phrase_vector: vector_alphabet::Vector, pub vectors: Vec, - pub words: Vec>, + pub words: Vec>, } -pub fn build_dictionary(phrase: &String, unique_words: &[String]) -> Dictionary { +pub fn build_dictionary(phrase: &String, unique_words: Vec) -> Dictionary { let alphabet = vector_alphabet::Alphabet::new(phrase).unwrap(); let 
phrase_with_metadata = alphabet.vectorize(phrase).unwrap(); @@ -35,7 +66,7 @@ pub fn build_dictionary(phrase: &String, unique_words: &[String]) -> Dictionary let mut words_by_vectors: HashMap<_, _> = HashMap::new(); for (word, vector_with_metadata) in words_with_vectors { let (_, words_for_vector) = words_by_vectors.entry(vector_with_metadata.key).or_insert((vector_with_metadata.vector, vec![])); - words_for_vector.push(word.clone()); + words_for_vector.push(WordInfo::new(word)); } let mut words_by_vectors: Vec<_> = words_by_vectors.into_values().collect(); diff --git a/src/hash_computer.rs b/src/hash_computer.rs index 029e3bd..ddc291f 100644 --- a/src/hash_computer.rs +++ b/src/hash_computer.rs @@ -1 +1,183 @@ -pub const MAX_PHRASE_LENGTH: usize = 27; \ No newline at end of file +use packed_simd::FromBits; +use packed_simd::u32x8; +use packed_simd::u8x32; + +pub const MAX_PHRASE_LENGTH: usize = 31; + +#[allow(unused_assignments)] +pub fn compute_hashes(messages: [u8x32; 8], messages_length: usize) -> [u32; 8] { + let mut a: u32x8 = u32x8::splat(0x67452301); + let mut b: u32x8 = u32x8::splat(0xefcdab89); + let mut c: u32x8 = u32x8::splat(0x98badcfe); + let mut d: u32x8 = u32x8::splat(0x10325476); + + let trailer = u8x32::splat(0).replace(messages_length, b' ' ^ 0x80); + let mut messages_bytes: [u32; 64] = [0; 64]; + { + macro_rules! write_bytes { + ($i: expr) => { + u32x8::from_bits(messages[$i] ^ trailer).write_to_slice_unaligned(&mut messages_bytes[($i*8)..]) + } + } + write_bytes!(0); + write_bytes!(1); + write_bytes!(2); + write_bytes!(3); + write_bytes!(4); + write_bytes!(5); + write_bytes!(6); + write_bytes!(7); + } + + macro_rules! 
get_m_value { + ($i: expr) => { + u32x8::new( + messages_bytes[0*8 + $i], + messages_bytes[1*8 + $i], + messages_bytes[2*8 + $i], + messages_bytes[3*8 + $i], + messages_bytes[4*8 + $i], + messages_bytes[5*8 + $i], + messages_bytes[6*8 + $i], + messages_bytes[7*8 + $i], + ) + }; + } + + let m0: u32x8 = get_m_value!(0); + let m1: u32x8 = get_m_value!(1); + let m2: u32x8 = get_m_value!(2); + let m3: u32x8 = get_m_value!(3); + let m4: u32x8 = get_m_value!(4); + let m5: u32x8 = get_m_value!(5); + let m6: u32x8 = get_m_value!(6); + let m7: u32x8 = get_m_value!(7); + let m14: u32x8 = u32x8::splat((messages_length as u32) * 8); + + macro_rules! lrot { + ($f: expr, $s: expr) => (($f << $s) | ($f >> (32-$s))); + } + + macro_rules! blend { + ($mask: expr, $a: expr, $b: expr) => { + // andnot (_mm256_andnot_si256) is not implemented in packed_simd + ($a & $mask) | ($b & !$mask) + } + } + + macro_rules! step { + ($f: expr, $s: expr, $k: expr, $m: expr) => { + let f = $f + a + u32x8::splat($k) + $m; + a = d; + d = c; + c = b; + b = b + lrot!(f, $s); + }; + ($f: expr, $s: expr, $k: expr) => { + let f = $f + a + u32x8::splat($k); + a = d; + d = c; + c = b; + b = b + lrot!(f, $s); + }; + } + + { + macro_rules! step_1 { + () => (blend!(b, c, d)); + } + + step!(step_1!(), 7, 0xd76aa478, m0); + step!(step_1!(), 12, 0xe8c7b756, m1); + step!(step_1!(), 17, 0x242070db, m2); + step!(step_1!(), 22, 0xc1bdceee, m3); + step!(step_1!(), 7, 0xf57c0faf, m4); + step!(step_1!(), 12, 0x4787c62a, m5); + step!(step_1!(), 17, 0xa8304613, m6); + step!(step_1!(), 22, 0xfd469501, m7); + step!(step_1!(), 7, 0x698098d8); + step!(step_1!(), 12, 0x8b44f7af); + step!(step_1!(), 17, 0xffff5bb1); + step!(step_1!(), 22, 0x895cd7be); + step!(step_1!(), 7, 0x6b901122); + step!(step_1!(), 12, 0xfd987193); + step!(step_1!(), 17, 0xa679438e, m14); + step!(step_1!(), 22, 0x49b40821); + } + + { + macro_rules! 
step_2 { + () => (blend!(d, b, c)); + } + + step!(step_2!(), 5, 0xf61e2562, m1); + step!(step_2!(), 9, 0xc040b340, m6); + step!(step_2!(), 14, 0x265e5a51); + step!(step_2!(), 20, 0xe9b6c7aa, m0); + step!(step_2!(), 5, 0xd62f105d, m5); + step!(step_2!(), 9, 0x02441453); + step!(step_2!(), 14, 0xd8a1e681); + step!(step_2!(), 20, 0xe7d3fbc8, m4); + step!(step_2!(), 5, 0x21e1cde6); + step!(step_2!(), 9, 0xc33707d6, m14); + step!(step_2!(), 14, 0xf4d50d87, m3); + step!(step_2!(), 20, 0x455a14ed); + step!(step_2!(), 5, 0xa9e3e905); + step!(step_2!(), 9, 0xfcefa3f8, m2); + step!(step_2!(), 14, 0x676f02d9, m7); + step!(step_2!(), 20, 0x8d2a4c8a); + } + + { + macro_rules! step_3 { + () => (b ^ (c ^ d)); + } + + step!(step_3!(), 4, 0xfffa3942, m5); + step!(step_3!(), 11, 0x8771f681); + step!(step_3!(), 16, 0x6d9d6122); + step!(step_3!(), 23, 0xfde5380c, m14); + step!(step_3!(), 4, 0xa4beea44, m1); + step!(step_3!(), 11, 0x4bdecfa9, m4); + step!(step_3!(), 16, 0xf6bb4b60, m7); + step!(step_3!(), 23, 0xbebfbc70); + step!(step_3!(), 4, 0x289b7ec6); + step!(step_3!(), 11, 0xeaa127fa, m0); + step!(step_3!(), 16, 0xd4ef3085, m3); + step!(step_3!(), 23, 0x04881d05, m6); + step!(step_3!(), 4, 0xd9d4d039); + step!(step_3!(), 11, 0xe6db99e5); + step!(step_3!(), 16, 0x1fa27cf8); + step!(step_3!(), 23, 0xc4ac5665, m2); + } + + { + macro_rules! 
step_4 { + () => (c ^ (b | !d)); + } + + step!(step_4!(), 6, 0xf4292244, m0); + step!(step_4!(), 10, 0x432aff97, m7); + step!(step_4!(), 15, 0xab9423a7, m14); + step!(step_4!(), 21, 0xfc93a039, m5); + step!(step_4!(), 6, 0x655b59c3); + step!(step_4!(), 10, 0x8f0ccc92, m3); + step!(step_4!(), 15, 0xffeff47d); + step!(step_4!(), 21, 0x85845dd1, m1); + step!(step_4!(), 6, 0x6fa87e4f); + step!(step_4!(), 10, 0xfe2ce6e0); + step!(step_4!(), 15, 0xa3014314, m6); + step!(step_4!(), 21, 0x4e0811a1); + step!(step_4!(), 6, 0xf7537e82, m4); + + // Since we ignore b, c, d values in the end, + // the remaining three iterations are unnecessary, + // as the value of a after iteration 64 is equal + // to the value of b after iteration 61 + a = b + u32x8::splat(0x67452301); + + let mut result: [u32; 8] = [0; 8]; + a.write_to_slice_unaligned(&mut result); + result + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6875814 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,9 @@ +#![feature(map_into_keys_values)] + +pub mod anagram_finder; +pub mod anagram_logger; +pub mod dictionary_builder; +pub mod hash_computer; +pub mod permutations_cache; +pub mod read_lines; +pub mod vector_alphabet; diff --git a/src/main.rs b/src/main.rs index 74340c1..aa7cb4d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,13 +2,13 @@ use std::cmp; use std::env; +use rayon::prelude::*; -mod anagram_finder; -mod anagram_logger; -mod dictionary_builder; -mod hash_computer; -mod read_lines; -mod vector_alphabet; +use trustpilot_challenge_rust::anagram_finder; +use trustpilot_challenge_rust::anagram_logger; +use trustpilot_challenge_rust::dictionary_builder; +use trustpilot_challenge_rust::hash_computer; +use trustpilot_challenge_rust::read_lines; fn main() { let args: Vec<_> = env::args().collect(); @@ -19,6 +19,12 @@ fn main() { let max_requested_number_of_words = (&args[3]).parse::().unwrap(); let phrase = &args[4]; + /*let message = hash_computer::prepare_messages(phrase); + let hashes = 
hash_computer::compute_hashes(message, phrase.len()); + for hash in hashes.iter() { + println!("{:#08x}", hash); + }*/ + let phrase_byte_length_without_spaces = phrase.as_bytes().into_iter().filter(|&b| *b != b' ').count(); let max_supported_number_of_words = (hash_computer::MAX_PHRASE_LENGTH - phrase_byte_length_without_spaces) + 1; @@ -31,12 +37,11 @@ fn main() { words.sort(); words.dedup(); - let dictionary = dictionary_builder::build_dictionary(phrase, &words); + let dictionary = dictionary_builder::build_dictionary(phrase, words); for number_of_words in 0..=max_number_of_words { let result = anagram_finder::find_anagrams(&dictionary, number_of_words); - for anagram in result { - println!("{}", anagram_logger::get_anagram_view(anagram, &dictionary)); - } + result.into_par_iter() + .for_each(|anagram| println!("{}", anagram_logger::get_anagram_view(anagram, &dictionary))); } } diff --git a/src/permutations_cache.rs b/src/permutations_cache.rs new file mode 100644 index 0000000..98f198a --- /dev/null +++ b/src/permutations_cache.rs @@ -0,0 +1 @@ +use permutohedron::Heap; \ No newline at end of file diff --git a/src/vector_alphabet.rs b/src/vector_alphabet.rs index 655d6a1..cb41294 100644 --- a/src/vector_alphabet.rs +++ b/src/vector_alphabet.rs @@ -23,7 +23,7 @@ impl Vector { pub fn is_subset_of(&self, other: &Vector) -> bool { let comparison_result = packed_simd::u8x32::gt(self.simd_vector, other.simd_vector); - packed_simd::m8x32::none(comparison_result as packed_simd::m8x32) + packed_simd::m8x32::none(comparison_result) } pub fn safe_substract(&self, vector_to_substract: &Vector) -> Option { diff --git a/tests/hash_computer_test.rs b/tests/hash_computer_test.rs new file mode 100644 index 0000000..7abba33 --- /dev/null +++ b/tests/hash_computer_test.rs @@ -0,0 +1,57 @@ +use packed_simd::u8x32; + +extern crate trustpilot_challenge_rust; +use trustpilot_challenge_rust::hash_computer; + +fn prepare_message(message_string: &str) -> u8x32 { + let mut bytes_static: 
[u8; 32] = [0; 32]; + let bytes = message_string.as_bytes(); + for i in 0..bytes.len() { + bytes_static[i] = bytes[i]; + } + + bytes_static[bytes.len()] = b' '; + + u8x32::from(bytes_static) +} + +fn prepare_messages(message_strings: [&str; 8]) -> [u8x32; 8] { + let mut result: [u8x32; 8] = [u8x32::splat(0); 8]; + for i in 0..8 { + result[i] = prepare_message(message_strings[i]); + } + result +} + +#[test] +fn it_computes_hashes() { + let messages: [&str; 8] = /*[""; 8];*/[ + "DAPUpOGHw620yalJA0vjFPK7ThgHyAN", + "4xRslaTeBCNyRu2EiIDueEx3BTbIP5H", + "kFPd2zk60eEFpNwgEOZAcyDcxRVv0Y8", + "bm6VQr6w9plie0G8XoOb4wChJXB0vCm", + "gbFrtHcqOTkeG1QxT8YEMSio1ahAYNq", + "T0GmOLB2WH04oIrhB3JCyPHFxI8UOow", + "TWUCy0B0JG5KjQvsu4YUFC5IR5ByS2W", + "VXqOIzYdLIqx6tw8LJbR7SqR5iYgTlQ" + ]; + + let expected: [u128; 8] = [ + 0xC6F9E9B203CEA81A7BA28BE276B96A6F, + 0x45BF15D8B08E1AEADE1305B8E43B8F2C, + 0x7AA53B4627C8DD3714F2874EDE04DA7D, + 0x93B650B474B6FDE6B902A76B1DDA10BB, + 0xBBF89511DAC63A516ADDB9BEC79241A5, + 0x7E53E601351A47500BF4A2B7EAD077C3, + 0xFC6FEB2CC3191198E87ECB9D7626580A, + 0xCED45BA82BD1BDCF546255CB6A530FE3, + ]; + + let messages_simd = prepare_messages(messages); + + let hashes = hash_computer::compute_hashes(messages_simd, messages[0].len()); + + for i in 0..8 { + assert_eq!((expected[i] >> 96) as u32, hashes[i].to_be()); + } +} \ No newline at end of file