highway/x86/avx.rs

#![allow(unsafe_code)]
use super::{v2x64u::V2x64U, v4x64u::V4x64U};
use crate::internal::unordered_load3;
use crate::internal::{HashPacket, PACKET_SIZE};
use crate::key::Key;
use crate::traits::HighwayHash;
use crate::PortableHash;
use core::arch::x86_64::*;

/// AVX-accelerated implementation that only works on `x86_64` with avx2 enabled at the
/// CPU level.
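///
/// # Example
///
/// A minimal usage sketch (arbitrary key values; assumes the crate root re-exports
/// `AvxHash`, `HighwayHash`, and `Key`):
///
/// ```
/// use highway::{AvxHash, HighwayHash, Key};
///
/// if let Some(mut hasher) = AvxHash::new(Key([1, 2, 3, 4])) {
///     hasher.append(b"hello world");
///     let _hash64 = hasher.finalize64();
/// }
/// ```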
#[derive(Debug, Default, Clone)]
pub struct AvxHash {
    v0: V4x64U,
    v1: V4x64U,
    mul0: V4x64U,
    mul1: V4x64U,
    buffer: HashPacket,
}

impl HighwayHash for AvxHash {
    #[inline]
    fn append(&mut self, data: &[u8]) {
        unsafe {
            self.append(data);
        }
    }

    #[inline]
    fn finalize64(mut self) -> u64 {
        unsafe { Self::finalize64(&mut self) }
    }

    #[inline]
    fn finalize128(mut self) -> [u64; 2] {
        unsafe { Self::finalize128(&mut self) }
    }

    #[inline]
    fn finalize256(mut self) -> [u64; 4] {
        unsafe { Self::finalize256(&mut self) }
    }

    #[inline]
    fn checkpoint(&self) -> [u8; 164] {
        PortableHash {
            v0: unsafe { self.v0.as_arr() },
            v1: unsafe { self.v1.as_arr() },
            mul0: unsafe { self.mul0.as_arr() },
            mul1: unsafe { self.mul1.as_arr() },
            buffer: self.buffer,
        }
        .checkpoint()
    }
}

impl AvxHash {
    /// Creates a new `AvxHash` while circumventing the runtime check for avx2.
    ///
    /// # Safety
    ///
    /// If called on a machine without avx2, the program will crash (typically with an
    /// illegal-instruction fault). Only use this if you control the deployment
    /// environment and have either benchmarked that the runtime check is significant or
    /// are unable to check for avx2 capabilities.
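    ///
    /// # Example
    ///
    /// A sketch of guarded use (hypothetical call site; the caller performs the avx2
    /// check itself):
    ///
    /// ```
    /// use highway::{AvxHash, Key};
    ///
    /// if is_x86_feature_detected!("avx2") {
    ///     let _hasher = unsafe { AvxHash::force_new(Key([1, 2, 3, 4])) };
    /// }
    /// ```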
    #[must_use]
    #[target_feature(enable = "avx2")]
    pub unsafe fn force_new(key: Key) -> Self {
        let mul0 = V4x64U::new(
            0x243f_6a88_85a3_08d3,
            0x1319_8a2e_0370_7344,
            0xa409_3822_299f_31d0,
            0xdbe6_d5d5_fe4c_ce2f,
        );
        let mul1 = V4x64U::new(
            0x4528_21e6_38d0_1377,
            0xbe54_66cf_34e9_0c6c,
            0xc0ac_f169_b5f1_8a8c,
            0x3bd3_9e10_cb0e_f593,
        );

        // Unaligned load: nothing guarantees the key's u64 array is 32-byte aligned.
        let key = V4x64U::from(_mm256_loadu_si256(key.0.as_ptr().cast::<__m256i>()));

        AvxHash {
            v0: key ^ mul0,
            v1: key.rotate_by_32() ^ mul1,
            mul0,
            mul1,
            buffer: HashPacket::default(),
        }
    }

    /// Creates a new `AvxHash` if the avx2 feature is detected.
    #[must_use]
    pub fn new(key: Key) -> Option<Self> {
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                Some(unsafe { Self::force_new(key) })
            } else {
                None
            }
        }

        #[cfg(not(feature = "std"))]
        {
            let _key = key;
            None
        }
    }

    /// Creates a new `AvxHash` from a checkpoint while circumventing the runtime check for avx2.
    ///
    /// # Safety
    ///
    /// See [`Self::force_new`] for safety concerns.
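    ///
    /// # Example
    ///
    /// A sketch of saving and restoring hashing state (assumes avx2 was already verified,
    /// e.g. by a successful [`Self::new`]):
    ///
    /// ```
    /// use highway::{AvxHash, HighwayHash, Key};
    ///
    /// if let Some(mut hasher) = AvxHash::new(Key([1, 2, 3, 4])) {
    ///     hasher.append(b"first half");
    ///     let saved = hasher.checkpoint();
    ///     // ...later, rebuild the hasher from the saved state.
    ///     let mut resumed = unsafe { AvxHash::force_from_checkpoint(saved) };
    ///     resumed.append(b"second half");
    ///     let _hash = resumed.finalize64();
    /// }
    /// ```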
    #[must_use]
    #[target_feature(enable = "avx2")]
    pub unsafe fn force_from_checkpoint(data: [u8; 164]) -> Self {
        let portable = PortableHash::from_checkpoint(data);
        AvxHash {
            v0: V4x64U::new(
                portable.v0[3],
                portable.v0[2],
                portable.v0[1],
                portable.v0[0],
            ),
            v1: V4x64U::new(
                portable.v1[3],
                portable.v1[2],
                portable.v1[1],
                portable.v1[0],
            ),
            mul0: V4x64U::new(
                portable.mul0[3],
                portable.mul0[2],
                portable.mul0[1],
                portable.mul0[0],
            ),
            mul1: V4x64U::new(
                portable.mul1[3],
                portable.mul1[2],
                portable.mul1[1],
                portable.mul1[0],
            ),
            buffer: portable.buffer,
        }
    }

    /// Creates a new `AvxHash` from a checkpoint if the avx2 feature is detected.
    #[must_use]
    pub fn from_checkpoint(data: [u8; 164]) -> Option<Self> {
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("avx2") {
                Some(unsafe { Self::force_from_checkpoint(data) })
            } else {
                None
            }
        }

        #[cfg(not(feature = "std"))]
        {
            let _ = data;
            None
        }
    }

    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn finalize64(&mut self) -> u64 {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..4 {
            let permuted = AvxHash::permute(&self.v0);
            self.update(permuted);
        }

        let sum0 = V2x64U::from(_mm256_castsi256_si128((self.v0 + self.mul0).0));
        let sum1 = V2x64U::from(_mm256_castsi256_si128((self.v1 + self.mul1).0));
        let hash = sum0 + sum1;
        let mut result: u64 = 0;
        // Each lane is sufficiently mixed, so just truncate to 64 bits.
        _mm_storel_epi64(core::ptr::addr_of_mut!(result).cast::<__m128i>(), hash.0);
        result
    }

    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn finalize128(&mut self) -> [u64; 2] {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..6 {
            let permuted = AvxHash::permute(&self.v0);
            self.update(permuted);
        }

        let sum0 = V2x64U::from(_mm256_castsi256_si128((self.v0 + self.mul0).0));
        let sum1 = V2x64U::from(_mm256_extracti128_si256((self.v1 + self.mul1).0, 1));
        let hash = sum0 + sum1;
        let mut result: [u64; 2] = [0; 2];
        _mm_storeu_si128(result.as_mut_ptr().cast::<__m128i>(), hash.0);
        result
    }

    #[target_feature(enable = "avx2")]
    pub(crate) unsafe fn finalize256(&mut self) -> [u64; 4] {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..10 {
            let permuted = AvxHash::permute(&self.v0);
            self.update(permuted);
        }

        let sum0 = self.v0 + self.mul0;
        let sum1 = self.v1 + self.mul1;
        let hash = AvxHash::modular_reduction(&sum1, &sum0);
        let mut result: [u64; 4] = [0; 4];
        _mm256_storeu_si256(result.as_mut_ptr().cast::<__m256i>(), hash.0);
        result
    }

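    // Reads a 32-byte packet as four u64 lanes with an unaligned 256-bit load.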
    #[inline]
    #[target_feature(enable = "avx2")]
    unsafe fn data_to_lanes(packet: &[u8]) -> V4x64U {
        V4x64U::from(_mm256_loadu_si256(packet.as_ptr().cast::<__m256i>()))
    }

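    // Builds a full 32-byte packet from a partial (< 32 byte) buffer: the leading whole
    // 4-byte words are read (directly or via masked loads) and the trailing bytes are
    // packed into the upper half of the packet.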
    #[target_feature(enable = "avx2")]
    unsafe fn remainder(bytes: &[u8]) -> V4x64U {
        let size_mod32 = bytes.len();
        let size256 = _mm256_broadcastd_epi32(_mm_cvtsi64_si128(size_mod32 as i64));
        let size_mod4 = size_mod32 & 3;
        let size = _mm256_castsi256_si128(size256);
        if size_mod32 & 16 != 0 {
            // Unaligned load: the byte slice carries no 16-byte alignment guarantee here.
            let packet_l = _mm_loadu_si128(bytes.as_ptr().cast::<__m128i>());
            let int_mask = _mm_cmpgt_epi32(size, _mm_set_epi32(31, 27, 23, 19));
            let int_lanes = _mm_maskload_epi32(bytes.as_ptr().offset(16).cast::<i32>(), int_mask);
            let remainder = &bytes[(size_mod32 & !3) + size_mod4 - 4..];
            let last4 =
                i32::from_le_bytes([remainder[0], remainder[1], remainder[2], remainder[3]]);
            let packet_h = _mm_insert_epi32(int_lanes, last4, 3);
            let packet_l256 = _mm256_castsi128_si256(packet_l);
            let packet = _mm256_inserti128_si256(packet_l256, packet_h, 1);
            V4x64U::from(packet)
        } else {
            let int_mask = _mm_cmpgt_epi32(size, _mm_set_epi32(15, 11, 7, 3));
            let packet_l = _mm_maskload_epi32(bytes.as_ptr().cast::<i32>(), int_mask);
            let remainder = &bytes[size_mod32 & !3..];
            let last3 = unordered_load3(remainder);
            let packet_h = _mm_cvtsi64_si128(last3 as i64);
            let packet_l256 = _mm256_castsi128_si256(packet_l);
            let packet = _mm256_inserti128_si256(packet_l256, packet_h, 1);
            V4x64U::from(packet)
        }
    }

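    // Mixes the leftover byte count into the state (added to v0, with each 32-bit half
    // of v1 rotated left by that count), then hashes the packet built by `remainder`.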
    #[target_feature(enable = "avx2")]
    unsafe fn update_remainder(&mut self) {
        let size = self.buffer.len();
        let size256 = _mm256_broadcastd_epi32(_mm_cvtsi64_si128(size as i64));
        self.v0 += V4x64U::from(size256);
        let shifted_left = V4x64U::from(_mm256_sllv_epi32(self.v1.0, size256));
        let tip = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(32));
        let shifted_right =
            V4x64U::from(_mm256_srlv_epi32(self.v1.0, _mm256_sub_epi32(tip, size256)));
        self.v1 = shifted_left | shifted_right;

        let packet = AvxHash::remainder(self.buffer.as_slice());
        self.update(packet);
    }

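    // Fixed byte shuffle (HighwayHash's "zipper merge") that redistributes bytes within
    // each 16-byte half to speed up diffusion of the multiplication results.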
    #[target_feature(enable = "avx2")]
    unsafe fn zipper_merge(v: &V4x64U) -> V4x64U {
        let hi = 0x0708_0609_0D0A_040B;
        let lo = 0x000F_010E_0502_0C03;
        v.shuffle(&V4x64U::new(hi, lo, hi, lo))
    }

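    // One update round: absorb the packet into v1, cross-multiply the 32-bit halves of
    // v0/v1 into mul0/mul1, and fold zipper-merged lanes back into v0 and v1.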
    #[target_feature(enable = "avx2")]
    unsafe fn update(&mut self, packet: V4x64U) {
        self.v1 += packet;
        self.v1 += self.mul0;
        self.mul0 ^= self.v1.mul_low32(&self.v0.shr_by_32());
        self.v0 += self.mul1;
        self.mul1 ^= self.v0.mul_low32(&self.v1.shr_by_32());
        self.v0 += AvxHash::zipper_merge(&self.v1);
        self.v1 += AvxHash::zipper_merge(&self.v0);
    }

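    // Swaps the 128-bit halves and the 32-bit halves of each 64-bit lane; applied to v0
    // during the extra finalization rounds.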
    #[target_feature(enable = "avx2")]
    unsafe fn permute(v: &V4x64U) -> V4x64U {
        let indices = V4x64U::new(
            0x0000_0002_0000_0003,
            0x0000_0000_0000_0001,
            0x0000_0006_0000_0007,
            0x0000_0004_0000_0005,
        );

        V4x64U::from(_mm256_permutevar8x32_epi32(v.0, indices.0))
    }

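    // Computes `init ^ (x << 1) ^ (x << 2)`, with the shifts spanning each 128-bit half
    // (the top two bits of each half are dropped): the reduction used by finalize256.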
    #[target_feature(enable = "avx2")]
    unsafe fn modular_reduction(x: &V4x64U, init: &V4x64U) -> V4x64U {
        let top_bits2 = V4x64U::from(_mm256_srli_epi64(x.0, 62));
        let ones = V4x64U::from(_mm256_cmpeq_epi64(x.0, x.0));
        let shifted1_unmasked = *x + *x;
        let top_bits1 = V4x64U::from(_mm256_srli_epi64(x.0, 63));
        let upper_8bytes = V4x64U::from(_mm256_slli_si256(ones.0, 8));
        let shifted2 = shifted1_unmasked + shifted1_unmasked;
        let upper_bit_of_128 = V4x64U::from(_mm256_slli_epi64(upper_8bytes.0, 63));
        let zero = V4x64U::from(_mm256_setzero_si256());
        let new_low_bits2 = V4x64U::from(_mm256_unpacklo_epi64(zero.0, top_bits2.0));
        let shifted1 = shifted1_unmasked.and_not(&upper_bit_of_128);
        let new_low_bits1 = V4x64U::from(_mm256_unpacklo_epi64(zero.0, top_bits1.0));

        *init ^ shifted2 ^ new_low_bits2 ^ shifted1 ^ new_low_bits1
    }

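    // Hashes whole 32-byte packets straight from `data`; any shorter tail is staged in
    // `buffer` until more bytes arrive or the hash is finalized.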
    #[target_feature(enable = "avx2")]
    unsafe fn append(&mut self, data: &[u8]) {
        if self.buffer.is_empty() {
            let mut chunks = data.chunks_exact(PACKET_SIZE);
            for chunk in chunks.by_ref() {
                self.update(Self::data_to_lanes(chunk));
            }
            self.buffer.set_to(chunks.remainder());
        } else if let Some(tail) = self.buffer.fill(data) {
            self.update(Self::data_to_lanes(self.buffer.inner()));
            let mut chunks = tail.chunks_exact(PACKET_SIZE);
            for chunk in chunks.by_ref() {
                self.update(Self::data_to_lanes(chunk));
            }

            self.buffer.set_to(chunks.remainder());
        }
    }
}

impl_write!(AvxHash);
impl_hasher!(AvxHash);