updated readme; implemented MD5 computation using SIMD

main
Inga 🏳‍🌈 3 years ago
parent 57f3877378
commit d6ba6a9199
  1. 7
      Cargo.toml
  2. 87
      README.md
  3. 6
      src/anagram_logger.rs
  4. 37
      src/dictionary_builder.rs
  5. 184
      src/hash_computer.rs
  6. 9
      src/lib.rs
  7. 25
      src/main.rs
  8. 1
      src/permutations_cache.rs
  9. 2
      src/vector_alphabet.rs
  10. 57
      tests/hash_computer_test.rs

@ -1,5 +1,5 @@
[package]
name = "hello_cargo"
name = "trustpilot_challenge_rust"
version = "0.1.0"
authors = ["inga-lovinde <52715130+inga-lovinde@users.noreply.github.com>"]
edition = "2018"
@ -7,4 +7,7 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
packed_simd = { version = "0.3.4", package = "packed_simd_2" }
crunchy = "0.2.2"
packed_simd = { version = "0.3.4", package = "packed_simd_2", features = ["into_bits"] }
permutohedron = "0.2.4"
rayon = "1.5.0"

@ -1,4 +1,4 @@
# TrustPilotChallengeRust
# TrustPilotChallengeRust
TrustPilot had this challenge several years ago
(http://followthewhiterabbit.trustpilot.com/)
@ -26,7 +26,12 @@ and does not yet do actual MD5 calculation.
## Algorithm description
Notably this solution does not involve string concatenation; strings are only concatenated for debugging purposes.
Notably this solution does not involve string concatenation;
strings are only concatenated for debugging purposes.
It also computes eight MD5 hashes at a time *per thread*
(that is, 128 MD5 hashes at once on a modern 8-core CPU),
with some further optimizations which further shave off
several percents from MD5 computation time.
We could split the problem into three parts: finding all anagrams
(up to words reordering and replacing some of the words with their single-word anagrams),
@ -70,6 +75,84 @@ could be further optimized, resulting in the following algorithm:
* W might be one element of a target subset, and the remaining elements could be found
by solving the task 2 for N-1, P-W and position of W in the list of vectors.
### Finding all anagrams, pt. 2
In the previous step, we just found all unique tuples of vectors with vectors ordered by norm decreasing
such that they give the required vector.
Now we need to convert these back to phrases.
If for every vector there was only one word which produces that vector,
and if all the vectors in a tuple were different,
we could just look at all their permutations and get n! solutions from a single tuple.
But a tuple can contain several copies of one vector,
and there could be several different words corresponding to one vector.
Computing all possible permutations would result in duplicate solutions
and too much unneccessary work.
So we could:
1. Substitute all possible word values for every vector, getting several (ordered) word solutions;
2. Apply all possible permutations to them such that, if vectors k and k+1 were the same in the vector solution,
word k should go before word k+1 in the word solution
(because the solution where word k goes after word k+1 is already obtained by a different substitution on step 1).
Every string shorter than 32 bytes could be represented as a single u8x32 AVX2 register
(with the remaining bytes filled with zeroes).
Concatenating strings could be as simple as XORing the vectors, shifted appropriately.
For example, to create `"a b "` string we would need to compute `"a " xor "␀␀b "`,
which is done in a single cycle on a modern CPU, provided that we have both vectors ready.
This is as opposed to concatenating strings which would require allocating a new string on the heap
and copying the data.
So we could just store all of the original words as such a vectors for all possible offsets
(along with trailing spaces), and when we need to compute a phrase consisting of the word x and the word y,
just do something along the lines of `get_register(x, 0) xor get_register(y, x.length)`
### Computing hashes
MD5 works on input messages in 64 byte blocks; for short strings (shorter than 55 bytes)
it only uses a single blocks: 0x80 byte is appended to the message, then it is padded to 56 bytes with zeroes,
and then the total length of the string in bits is appended as 64-bit number.
So short phrases (shorter than 31 bytes) could be represented with two AVX2 registers:
one containing the phrase itself with the trailing 0x80, and another containing 24 zeroes
and 64-bit length of the phrase in bits (which is the number of non-space bytes
plus the number of words, times 8).
For its internal state, MD5 has four 32-bit variables (u32).
This means that with AVX2, we can use the same operations on 256-bit registers
(u32x8) and compute eight hashes at the same time in a single thread.
MD5 breaks input chunks into 16 u32 words (and for short phrases chunks 8-14 are always zero),
so our algorithm could receive 8x256-bit values and the phrase length,
rearrange these into 9 256-bit values (8 obtained by transposing the original 8 as 8x8 matrix of u32,
and ninth being 8 copies of the phrase length in bits),
and then implement MD5 algorithms using these 9 values as input words 0..7, 15
(substituting 0 as input words 8..14).
That way, MD5 performance would be increased 8x compared to the ordinary library function
which does not use SIMD.
As a minor additional optimization, we could only compute the first u32 part of the MD5 hash
(because we don't need to compute entire hashes for all possible anagrams,
we only need to find anagrams which match the requested hashes.
That way, we'll save some unneeded steps in MD5 computation,
and we also won't have to convert hashes back to separate variables:
we could just compare u32x8 holding the first parts of hashes for eight different anagrams
with u32x8 holding eight copies of the first part of the requested hash.
That way, we'll only have one comparison instead of eight,
at the cost of rare false positives which occur on average with 1/2^29 probability
(1/2^32 chance that a random u32 matches the requested u32, for every of the eight anagrams).
If there is such a semi-match (that is, one of the eight anagrams produces a hash
with first 32 bits matching first 32 bits of the requested hash), we could just
compute MD5 for every of the eight anagrams in the ordinary way and
to compare the whole resulting hashes with the requested ones;
as this is extremely rare (once every 1/29th calls to SIMD MD5 function),
it will not severely affect performance.
## How to run
How to run to solve the original task for three-word anagrams:

@ -5,11 +5,11 @@ pub fn get_anagram_view(anagram: Vec<usize>, dictionary: &dictionary_builder::Di
.map(|&index| {
let word_options = &dictionary.words[index];
if word_options.len() == 1 {
word_options[0].clone()
word_options[0].word.clone()
} else {
format!("[{}]", word_options.join(","))
format!("[{}]", word_options.iter().map(|word_info| word_info.word.clone()).collect::<Vec<_>>().join(","))
}
})
.collect::<Vec<_>>()
.join(" ")
}
}

@ -1,13 +1,44 @@
use std::collections::HashMap;
use packed_simd;
use crate::vector_alphabet;
pub struct WordInfo {
simd_words: [packed_simd::u8x32; 32],
pub length: usize,
pub word: String,
}
impl WordInfo {
fn new(word: String) -> WordInfo {
let mut byte_array: [u8; 64] = [0; 64];
let bytes = word.as_bytes();
let length = bytes.len();
byte_array[32 + length] = b' ';
for i in 0..length {
byte_array[32 + i] = bytes[i];
}
let simd_word_zero: packed_simd::u8x32 = packed_simd::u8x32::from_slice_unaligned(&[0; 32]);
let mut simd_words: [packed_simd::u8x32; 32] = [simd_word_zero; 32];
for i in 0..31 {
simd_words[i] = packed_simd::u8x32::from_slice_unaligned(&byte_array[32-i..64-i]);
}
WordInfo {
simd_words,
length,
word,
}
}
}
pub struct Dictionary {
pub phrase_vector: vector_alphabet::Vector,
pub vectors: Vec<vector_alphabet::Vector>,
pub words: Vec<Vec<String>>,
pub words: Vec<Vec<WordInfo>>,
}
pub fn build_dictionary(phrase: &String, unique_words: &[String]) -> Dictionary {
pub fn build_dictionary(phrase: &String, unique_words: Vec<String>) -> Dictionary {
let alphabet = vector_alphabet::Alphabet::new(phrase).unwrap();
let phrase_with_metadata = alphabet.vectorize(phrase).unwrap();
@ -35,7 +66,7 @@ pub fn build_dictionary(phrase: &String, unique_words: &[String]) -> Dictionary
let mut words_by_vectors: HashMap<_, _> = HashMap::new();
for (word, vector_with_metadata) in words_with_vectors {
let (_, words_for_vector) = words_by_vectors.entry(vector_with_metadata.key).or_insert((vector_with_metadata.vector, vec![]));
words_for_vector.push(word.clone());
words_for_vector.push(WordInfo::new(word));
}
let mut words_by_vectors: Vec<_> = words_by_vectors.into_values().collect();

@ -1 +1,183 @@
pub const MAX_PHRASE_LENGTH: usize = 27;
use packed_simd::FromBits;
use packed_simd::u32x8;
use packed_simd::u8x32;
pub const MAX_PHRASE_LENGTH: usize = 31;
#[allow(unused_assignments)]
pub fn compute_hashes(messages: [u8x32; 8], messages_length: usize) -> [u32; 8] {
let mut a: u32x8 = u32x8::splat(0x67452301);
let mut b: u32x8 = u32x8::splat(0xefcdab89);
let mut c: u32x8 = u32x8::splat(0x98badcfe);
let mut d: u32x8 = u32x8::splat(0x10325476);
let trailer = u8x32::splat(0).replace(messages_length, b' ' ^ 0x80);
let mut messages_bytes: [u32; 64] = [0; 64];
{
macro_rules! write_bytes {
($i: expr) => {
u32x8::from_bits(messages[$i] ^ trailer).write_to_slice_unaligned(&mut messages_bytes[($i*8)..])
}
}
write_bytes!(0);
write_bytes!(1);
write_bytes!(2);
write_bytes!(3);
write_bytes!(4);
write_bytes!(5);
write_bytes!(6);
write_bytes!(7);
}
macro_rules! get_m_value {
($i: expr) => {
u32x8::new(
messages_bytes[0*8 + $i],
messages_bytes[1*8 + $i],
messages_bytes[2*8 + $i],
messages_bytes[3*8 + $i],
messages_bytes[4*8 + $i],
messages_bytes[5*8 + $i],
messages_bytes[6*8 + $i],
messages_bytes[7*8 + $i],
)
};
}
let m0: u32x8 = get_m_value!(0);
let m1: u32x8 = get_m_value!(1);
let m2: u32x8 = get_m_value!(2);
let m3: u32x8 = get_m_value!(3);
let m4: u32x8 = get_m_value!(4);
let m5: u32x8 = get_m_value!(5);
let m6: u32x8 = get_m_value!(6);
let m7: u32x8 = get_m_value!(7);
let m14: u32x8 = u32x8::splat((messages_length as u32) * 8);
macro_rules! lrot {
($f: expr, $s: expr) => (($f << $s) | ($f >> (32-$s)));
}
macro_rules! blend {
($mask: expr, $a: expr, $b: expr) => {
// andnot (_mm256_andnot_si256) is not implemented in packed_simd
($a & $mask) | ($b & !$mask)
}
}
macro_rules! step {
($f: expr, $s: expr, $k: expr, $m: expr) => {
let f = $f + a + u32x8::splat($k) + $m;
a = d;
d = c;
c = b;
b = b + lrot!(f, $s);
};
($f: expr, $s: expr, $k: expr) => {
let f = $f + a + u32x8::splat($k);
a = d;
d = c;
c = b;
b = b + lrot!(f, $s);
};
}
{
macro_rules! step_1 {
() => (blend!(b, c, d));
}
step!(step_1!(), 7, 0xd76aa478, m0);
step!(step_1!(), 12, 0xe8c7b756, m1);
step!(step_1!(), 17, 0x242070db, m2);
step!(step_1!(), 22, 0xc1bdceee, m3);
step!(step_1!(), 7, 0xf57c0faf, m4);
step!(step_1!(), 12, 0x4787c62a, m5);
step!(step_1!(), 17, 0xa8304613, m6);
step!(step_1!(), 22, 0xfd469501, m7);
step!(step_1!(), 7, 0x698098d8);
step!(step_1!(), 12, 0x8b44f7af);
step!(step_1!(), 17, 0xffff5bb1);
step!(step_1!(), 22, 0x895cd7be);
step!(step_1!(), 7, 0x6b901122);
step!(step_1!(), 12, 0xfd987193);
step!(step_1!(), 17, 0xa679438e, m14);
step!(step_1!(), 22, 0x49b40821);
}
{
macro_rules! step_2 {
() => (blend!(d, b, c));
}
step!(step_2!(), 5, 0xf61e2562, m1);
step!(step_2!(), 9, 0xc040b340, m6);
step!(step_2!(), 14, 0x265e5a51);
step!(step_2!(), 20, 0xe9b6c7aa, m0);
step!(step_2!(), 5, 0xd62f105d, m5);
step!(step_2!(), 9, 0x02441453);
step!(step_2!(), 14, 0xd8a1e681);
step!(step_2!(), 20, 0xe7d3fbc8, m4);
step!(step_2!(), 5, 0x21e1cde6);
step!(step_2!(), 9, 0xc33707d6, m14);
step!(step_2!(), 14, 0xf4d50d87, m3);
step!(step_2!(), 20, 0x455a14ed);
step!(step_2!(), 5, 0xa9e3e905);
step!(step_2!(), 9, 0xfcefa3f8, m2);
step!(step_2!(), 14, 0x676f02d9, m7);
step!(step_2!(), 20, 0x8d2a4c8a);
}
{
macro_rules! step_3 {
() => (b ^ (c ^ d));
}
step!(step_3!(), 4, 0xfffa3942, m5);
step!(step_3!(), 11, 0x8771f681);
step!(step_3!(), 16, 0x6d9d6122);
step!(step_3!(), 23, 0xfde5380c, m14);
step!(step_3!(), 4, 0xa4beea44, m1);
step!(step_3!(), 11, 0x4bdecfa9, m4);
step!(step_3!(), 16, 0xf6bb4b60, m7);
step!(step_3!(), 23, 0xbebfbc70);
step!(step_3!(), 4, 0x289b7ec6);
step!(step_3!(), 11, 0xeaa127fa, m0);
step!(step_3!(), 16, 0xd4ef3085, m3);
step!(step_3!(), 23, 0x04881d05, m6);
step!(step_3!(), 4, 0xd9d4d039);
step!(step_3!(), 11, 0xe6db99e5);
step!(step_3!(), 16, 0x1fa27cf8);
step!(step_3!(), 23, 0xc4ac5665, m2);
}
{
macro_rules! step_4 {
() => (c ^ (b | !d));
}
step!(step_4!(), 6, 0xf4292244, m0);
step!(step_4!(), 10, 0x432aff97, m7);
step!(step_4!(), 15, 0xab9423a7, m14);
step!(step_4!(), 21, 0xfc93a039, m5);
step!(step_4!(), 6, 0x655b59c3);
step!(step_4!(), 10, 0x8f0ccc92, m3);
step!(step_4!(), 15, 0xffeff47d);
step!(step_4!(), 21, 0x85845dd1, m1);
step!(step_4!(), 6, 0x6fa87e4f);
step!(step_4!(), 10, 0xfe2ce6e0);
step!(step_4!(), 15, 0xa3014314, m6);
step!(step_4!(), 21, 0x4e0811a1);
step!(step_4!(), 6, 0xf7537e82, m4);
// Since we ignore b, c, d values in the end,
// the remaining three iterations are unnecessary,
// as the value of a after iteration 64 is equal
// to the value of b after iteration 61
a = b + u32x8::splat(0x67452301);
let mut result: [u32; 8] = [0; 8];
a.write_to_slice_unaligned(&mut result);
result
}
}

@ -0,0 +1,9 @@
#![feature(map_into_keys_values)]
pub mod anagram_finder;
pub mod anagram_logger;
pub mod dictionary_builder;
pub mod hash_computer;
pub mod permutations_cache;
pub mod read_lines;
pub mod vector_alphabet;

@ -2,13 +2,13 @@
use std::cmp;
use std::env;
use rayon::prelude::*;
mod anagram_finder;
mod anagram_logger;
mod dictionary_builder;
mod hash_computer;
mod read_lines;
mod vector_alphabet;
use trustpilot_challenge_rust::anagram_finder;
use trustpilot_challenge_rust::anagram_logger;
use trustpilot_challenge_rust::dictionary_builder;
use trustpilot_challenge_rust::hash_computer;
use trustpilot_challenge_rust::read_lines;
fn main() {
let args: Vec<_> = env::args().collect();
@ -19,6 +19,12 @@ fn main() {
let max_requested_number_of_words = (&args[3]).parse::<usize>().unwrap();
let phrase = &args[4];
/*let message = hash_computer::prepare_messages(phrase);
let hashes = hash_computer::compute_hashes(message, phrase.len());
for hash in hashes.iter() {
println!("{:#08x}", hash);
}*/
let phrase_byte_length_without_spaces = phrase.as_bytes().into_iter().filter(|&b| *b != b' ').count();
let max_supported_number_of_words = (hash_computer::MAX_PHRASE_LENGTH - phrase_byte_length_without_spaces) + 1;
@ -31,12 +37,11 @@ fn main() {
words.sort();
words.dedup();
let dictionary = dictionary_builder::build_dictionary(phrase, &words);
let dictionary = dictionary_builder::build_dictionary(phrase, words);
for number_of_words in 0..=max_number_of_words {
let result = anagram_finder::find_anagrams(&dictionary, number_of_words);
for anagram in result {
println!("{}", anagram_logger::get_anagram_view(anagram, &dictionary));
}
result.into_par_iter()
.for_each(|anagram| println!("{}", anagram_logger::get_anagram_view(anagram, &dictionary)));
}
}

@ -0,0 +1 @@
use permutohedron::Heap;

@ -23,7 +23,7 @@ impl Vector {
pub fn is_subset_of(&self, other: &Vector) -> bool {
let comparison_result = packed_simd::u8x32::gt(self.simd_vector, other.simd_vector);
packed_simd::m8x32::none(comparison_result as packed_simd::m8x32)
packed_simd::m8x32::none(comparison_result)
}
pub fn safe_substract(&self, vector_to_substract: &Vector) -> Option<Vector> {

@ -0,0 +1,57 @@
use packed_simd::u8x32;
extern crate trustpilot_challenge_rust;
use trustpilot_challenge_rust::hash_computer;
fn prepare_message(message_string: &str) -> u8x32 {
let mut bytes_static: [u8; 32] = [0; 32];
let bytes = message_string.as_bytes();
for i in 0..bytes.len() {
bytes_static[i] = bytes[i];
}
bytes_static[bytes.len()] = b' ';
u8x32::from(bytes_static)
}
fn prepare_messages(message_strings: [&str; 8]) -> [u8x32; 8] {
let mut result: [u8x32; 8] = [u8x32::splat(0); 8];
for i in 0..8 {
result[i] = prepare_message(message_strings[i]);
}
result
}
#[test]
fn it_computes_hashes() {
let messages: [&str; 8] = /*[""; 8];*/[
"DAPUpOGHw620yalJA0vjFPK7ThgHyAN",
"4xRslaTeBCNyRu2EiIDueEx3BTbIP5H",
"kFPd2zk60eEFpNwgEOZAcyDcxRVv0Y8",
"bm6VQr6w9plie0G8XoOb4wChJXB0vCm",
"gbFrtHcqOTkeG1QxT8YEMSio1ahAYNq",
"T0GmOLB2WH04oIrhB3JCyPHFxI8UOow",
"TWUCy0B0JG5KjQvsu4YUFC5IR5ByS2W",
"VXqOIzYdLIqx6tw8LJbR7SqR5iYgTlQ"
];
let expected: [u128; 8] = [
0xC6F9E9B203CEA81A7BA28BE276B96A6F,
0x45BF15D8B08E1AEADE1305B8E43B8F2C,
0x7AA53B4627C8DD3714F2874EDE04DA7D,
0x93B650B474B6FDE6B902A76B1DDA10BB,
0xBBF89511DAC63A516ADDB9BEC79241A5,
0x7E53E601351A47500BF4A2B7EAD077C3,
0xFC6FEB2CC3191198E87ECB9D7626580A,
0xCED45BA82BD1BDCF546255CB6A530FE3,
];
let messages_simd = prepare_messages(messages);
let hashes = hash_computer::compute_hashes(messages_simd, messages[0].len());
for i in 0..8 {
assert_eq!((expected[i] >> 96) as u32, hashes[i].to_be());
}
}
Loading…
Cancel
Save