highway/x86/sse.rs

#![allow(unsafe_code)]
use super::v2x64u::V2x64U;
use crate::internal::unordered_load3;
use crate::internal::{HashPacket, PACKET_SIZE};
use crate::key::Key;
use crate::traits::HighwayHash;
use crate::PortableHash;
use core::arch::x86_64::*;

/// SSE-powered implementation that will only work on `x86_64` with SSE 4.1 enabled at the CPU
/// level.
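///
/// # Example
///
/// A minimal usage sketch (it assumes the crate root re-exports `Key` and `SseHash`, that the
/// host CPU supports SSE 4.1 at runtime, and the key value is only illustrative):
///
/// ```no_run
/// use highway::{HighwayHash, Key, SseHash};
///
/// let mut hasher = SseHash::new(Key([1, 2, 3, 4])).expect("sse4.1 not detected");
/// hasher.append(b"hello world");
/// let hash64 = hasher.finalize64();
/// ```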
#[derive(Debug, Default, Clone)]
pub struct SseHash {
    v0L: V2x64U,
    v0H: V2x64U,
    v1L: V2x64U,
    v1H: V2x64U,
    mul0L: V2x64U,
    mul0H: V2x64U,
    mul1L: V2x64U,
    mul1H: V2x64U,
    buffer: HashPacket,
}

impl HighwayHash for SseHash {
    #[inline]
    fn append(&mut self, data: &[u8]) {
        unsafe {
            self.append(data);
        }
    }

    #[inline]
    fn finalize64(mut self) -> u64 {
        unsafe { Self::finalize64(&mut self) }
    }

    #[inline]
    fn finalize128(mut self) -> [u64; 2] {
        unsafe { Self::finalize128(&mut self) }
    }

    #[inline]
    fn finalize256(mut self) -> [u64; 4] {
        unsafe { Self::finalize256(&mut self) }
    }

    #[inline]
    fn checkpoint(&self) -> [u8; 164] {
        let mut v0 = [0u64; 4];
        v0[..2].copy_from_slice(unsafe { &self.v0L.as_arr() });
        v0[2..].copy_from_slice(unsafe { &self.v0H.as_arr() });

        let mut v1 = [0u64; 4];
        v1[..2].copy_from_slice(unsafe { &self.v1L.as_arr() });
        v1[2..].copy_from_slice(unsafe { &self.v1H.as_arr() });

        let mut mul0 = [0u64; 4];
        mul0[..2].copy_from_slice(unsafe { &self.mul0L.as_arr() });
        mul0[2..].copy_from_slice(unsafe { &self.mul0H.as_arr() });

        let mut mul1 = [0u64; 4];
        mul1[..2].copy_from_slice(unsafe { &self.mul1L.as_arr() });
        mul1[2..].copy_from_slice(unsafe { &self.mul1H.as_arr() });

        PortableHash {
            v0,
            v1,
            mul0,
            mul1,
            buffer: self.buffer,
        }
        .checkpoint()
    }
}

impl SseHash {
    /// Creates a new `SseHash` while circumventing the runtime check for sse4.1.
    ///
    /// # Safety
    ///
    /// If called on a machine without sse4.1, the behavior is undefined (in practice, an
    /// illegal instruction fault). Only use this if you control the deployment environment
    /// and have either benchmarked that the runtime check is significant or are unable to
    /// check for sse4.1 capabilities.
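    ///
    /// # Example
    ///
    /// A hedged sketch: only sound when SSE 4.1 support is already guaranteed by the
    /// deployment target (the key value is illustrative).
    ///
    /// ```no_run
    /// use highway::{Key, SseHash};
    ///
    /// // Assumption: every machine this binary runs on supports SSE 4.1.
    /// let hasher = unsafe { SseHash::force_new(Key([1, 2, 3, 4])) };
    /// ```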
    #[must_use]
    #[target_feature(enable = "sse4.1")]
    pub unsafe fn force_new(key: Key) -> Self {
        let init0L = V2x64U::new(0xa409_3822_299f_31d0, 0xdbe6_d5d5_fe4c_ce2f);
        let init0H = V2x64U::new(0x243f_6a88_85a3_08d3, 0x1319_8a2e_0370_7344);
        let init1L = V2x64U::new(0xc0ac_f169_b5f1_8a8c, 0x3bd3_9e10_cb0e_f593);
        let init1H = V2x64U::new(0x4528_21e6_38d0_1377, 0xbe54_66cf_34e9_0c6c);
        let key_ptr = key.0.as_ptr().cast::<__m128i>();
        let keyL = V2x64U::from(_mm_loadu_si128(key_ptr));
        let keyH = V2x64U::from(_mm_loadu_si128(key_ptr.add(1)));

        SseHash {
            v0L: keyL ^ init0L,
            v0H: keyH ^ init0H,
            v1L: keyL.rotate_by_32() ^ init1L,
            v1H: keyH.rotate_by_32() ^ init1H,
            mul0L: init0L,
            mul0H: init0H,
            mul1L: init1L,
            mul1H: init1H,
            buffer: HashPacket::default(),
        }
    }

    /// Create a new `SseHash` if the sse4.1 feature is detected
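    ///
    /// # Example
    ///
    /// A sketch of handling the detection result (the fallback branch is illustrative):
    ///
    /// ```no_run
    /// use highway::{Key, SseHash};
    ///
    /// match SseHash::new(Key([1, 2, 3, 4])) {
    ///     Some(_hasher) => { /* hash with the SSE 4.1 path */ }
    ///     None => { /* fall back to a portable implementation */ }
    /// }
    /// ```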
    #[must_use]
    pub fn new(key: Key) -> Option<Self> {
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("sse4.1") {
                Some(unsafe { Self::force_new(key) })
            } else {
                None
            }
        }

        #[cfg(not(feature = "std"))]
        {
            let _key = key;
            None
        }
    }

    /// Creates a new `SseHash` from a checkpoint while circumventing the runtime check for sse4.1.
    ///
    /// # Safety
    ///
    /// See [`Self::force_new`] for safety concerns.
    #[must_use]
    #[target_feature(enable = "sse4.1")]
    pub unsafe fn force_from_checkpoint(data: [u8; 164]) -> Self {
        let portable = PortableHash::from_checkpoint(data);
        SseHash {
            v0L: V2x64U::new(portable.v0[1], portable.v0[0]),
            v0H: V2x64U::new(portable.v0[3], portable.v0[2]),
            v1L: V2x64U::new(portable.v1[1], portable.v1[0]),
            v1H: V2x64U::new(portable.v1[3], portable.v1[2]),
            mul0L: V2x64U::new(portable.mul0[1], portable.mul0[0]),
            mul0H: V2x64U::new(portable.mul0[3], portable.mul0[2]),
            mul1L: V2x64U::new(portable.mul1[1], portable.mul1[0]),
            mul1H: V2x64U::new(portable.mul1[3], portable.mul1[2]),
            buffer: portable.buffer,
        }
    }

    /// Create a new `SseHash` from a checkpoint if the sse4.1 feature is detected
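    ///
    /// # Example
    ///
    /// A minimal sketch of checkpointing and resuming (key and data are illustrative):
    ///
    /// ```no_run
    /// use highway::{HighwayHash, Key, SseHash};
    ///
    /// let mut hasher = SseHash::new(Key([1, 2, 3, 4])).expect("sse4.1 not detected");
    /// hasher.append(b"first half");
    /// let state = hasher.checkpoint();
    ///
    /// // Later: restore the saved state and continue where we left off.
    /// let mut resumed = SseHash::from_checkpoint(state).expect("sse4.1 not detected");
    /// resumed.append(b"second half");
    /// let hash64 = resumed.finalize64();
    /// ```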
    #[must_use]
    pub fn from_checkpoint(data: [u8; 164]) -> Option<Self> {
        #[cfg(feature = "std")]
        {
            if is_x86_feature_detected!("sse4.1") {
                Some(unsafe { Self::force_from_checkpoint(data) })
            } else {
                None
            }
        }

        #[cfg(not(feature = "std"))]
        {
            let _ = data;
            None
        }
    }

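    /// Byte shuffle used by the mixing rounds; the constant encodes the HighwayHash
    /// "zipper merge" permutation within each 16-byte lane.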
    #[target_feature(enable = "sse4.1")]
    unsafe fn zipper_merge(v: &V2x64U) -> V2x64U {
        v.shuffle(&V2x64U::new(0x0708_0609_0D0A_040B, 0x000F_010E_0502_0C03))
    }

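    /// One HighwayHash round over a 32-byte packet (split into high/low 128-bit lanes):
    /// the packet and `mul0` are folded into `v1`, 32x32-bit multiplies cross `v0`/`v1`
    /// into the `mul` accumulators, and zipper merges diffuse bytes between `v0` and `v1`.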
    #[target_feature(enable = "sse4.1")]
    unsafe fn update(&mut self, (packetH, packetL): (V2x64U, V2x64U)) {
        self.v1L += packetL;
        self.v1H += packetH;
        self.v1L += self.mul0L;
        self.v1H += self.mul0H;
        self.mul0L ^= V2x64U(_mm_mul_epu32(self.v1L.0, self.v0L.rotate_by_32().0));
        self.mul0H ^= V2x64U(_mm_mul_epu32(self.v1H.0, _mm_srli_epi64(self.v0H.0, 32)));
        self.v0L += self.mul1L;
        self.v0H += self.mul1H;
        self.mul1L ^= V2x64U(_mm_mul_epu32(self.v0L.0, self.v1L.rotate_by_32().0));
        self.mul1H ^= V2x64U(_mm_mul_epu32(self.v0H.0, _mm_srli_epi64(self.v1H.0, 32)));
        self.v0L += SseHash::zipper_merge(&self.v1L);
        self.v0H += SseHash::zipper_merge(&self.v1H);
        self.v1L += SseHash::zipper_merge(&self.v0L);
        self.v1H += SseHash::zipper_merge(&self.v0H);
    }

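    /// Finalization round: permutes `v0` (its halves are swapped and each 64-bit lane is
    /// rotated by 32 bits) and feeds the result back through `update` as if it were input.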
    #[target_feature(enable = "sse4.1")]
    unsafe fn permute_and_update(&mut self) {
        let low = self.v0L.rotate_by_32();
        let high = self.v0H.rotate_by_32();
        self.update((low, high));
    }

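    /// Flushes any buffered partial packet, runs four extra permutation rounds, and
    /// returns the low 64 bits of `v0L + v1L + mul0L + mul1L`.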
    #[target_feature(enable = "sse4.1")]
    pub(crate) unsafe fn finalize64(&mut self) -> u64 {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..4 {
            self.permute_and_update();
        }

        let sum0 = self.v0L + self.mul0L;
        let sum1 = self.v1L + self.mul1L;
        let hash = sum0 + sum1;
        let mut result: u64 = 0;
        _mm_storel_epi64(core::ptr::addr_of_mut!(result).cast::<__m128i>(), hash.0);
        result
    }

    #[target_feature(enable = "sse4.1")]
    pub(crate) unsafe fn finalize128(&mut self) -> [u64; 2] {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..6 {
            self.permute_and_update();
        }

        let sum0 = self.v0L + self.mul0L;
        let sum1 = self.v1H + self.mul1H;
        let hash = sum0 + sum1;
        let mut result: [u64; 2] = [0; 2];
        _mm_storeu_si128(result.as_mut_ptr().cast::<__m128i>(), hash.0);
        result
    }

    #[target_feature(enable = "sse4.1")]
    pub(crate) unsafe fn finalize256(&mut self) -> [u64; 4] {
        if !self.buffer.is_empty() {
            self.update_remainder();
        }

        for _i in 0..10 {
            self.permute_and_update();
        }

        let sum0L = self.v0L + self.mul0L;
        let sum1L = self.v1L + self.mul1L;
        let sum0H = self.v0H + self.mul0H;
        let sum1H = self.v1H + self.mul1H;
        let hashL = SseHash::modular_reduction(&sum1L, &sum0L);
        let hashH = SseHash::modular_reduction(&sum1H, &sum0H);
        let mut result: [u64; 4] = [0; 4];
        let ptr = result.as_mut_ptr().cast::<__m128i>();
        _mm_storeu_si128(ptr, hashL.0);
        _mm_storeu_si128(ptr.add(1), hashH.0);
        result
    }

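    /// Modular-reduction step used for the 256-bit output: the top bits of `x` are folded
    /// back into its shifted copies, which are then XORed into `init`.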
    #[target_feature(enable = "sse4.1")]
    unsafe fn modular_reduction(x: &V2x64U, init: &V2x64U) -> V2x64U {
        let zero = V2x64U::default();
        let sign_bit128 = V2x64U::from(_mm_insert_epi32(zero.0, 0x8000_0000_u32 as i32, 3));
        let top_bits2 = V2x64U::from(_mm_srli_epi64(x.0, 62));
        let shifted1_unmasked = *x + *x;
        let top_bits1 = V2x64U::from(_mm_srli_epi64(x.0, 63));
        let shifted2 = shifted1_unmasked + shifted1_unmasked;
        let new_low_bits2 = V2x64U::from(_mm_slli_si128(top_bits2.0, 8));
        let shifted1 = shifted1_unmasked.and_not(&sign_bit128);
        let new_low_bits1 = V2x64U::from(_mm_slli_si128(top_bits1.0, 8));
        *init ^ shifted2 ^ new_low_bits2 ^ shifted1 ^ new_low_bits1
    }

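    /// Loads the leading multiple-of-four prefix of `bytes` (0, 4, 8, or 12 bytes for the
    /// sub-16-byte slices this sees) into a vector, leaving the remaining lanes zero.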
    #[target_feature(enable = "sse4.1")]
    unsafe fn load_multiple_of_four(bytes: &[u8]) -> V2x64U {
        let mut data = bytes;
        let mut mask4 = V2x64U::from(_mm_cvtsi64_si128(0xFFFF_FFFF));
        let mut ret = if bytes.len() >= 8 {
            mask4 = V2x64U::from(_mm_slli_si128(mask4.0, 8));
            data = &bytes[8..];
            V2x64U::from(_mm_loadl_epi64(bytes.as_ptr().cast::<__m128i>()))
        } else {
            V2x64U::new(0, 0)
        };

        if let Some(d) = data.get(..4) {
            let last4 = i32::from_le_bytes([d[0], d[1], d[2], d[3]]);
            let broadcast = V2x64U::from(_mm_set1_epi32(last4));
            ret |= broadcast & mask4;
        }

        ret
    }

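    /// Packs the final partial packet (fewer than 32 bytes) into a (high, low) lane pair,
    /// mirroring the reference HighwayHash handling of trailing bytes.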
    #[target_feature(enable = "sse4.1")]
    unsafe fn remainder(bytes: &[u8]) -> (V2x64U, V2x64U) {
        let size_mod32 = bytes.len();
        let size_mod4 = size_mod32 & 3;
        if size_mod32 & 16 != 0 {
            let packetL = V2x64U::from(_mm_loadu_si128(bytes.as_ptr().cast::<__m128i>()));
            let packett = SseHash::load_multiple_of_four(&bytes[16..]);
            let remainder = &bytes[(size_mod32 & !3) + size_mod4 - 4..];
            let last4 =
                i32::from_le_bytes([remainder[0], remainder[1], remainder[2], remainder[3]]);
            let packetH = V2x64U::from(_mm_insert_epi32(packett.0, last4, 3));
            (packetH, packetL)
        } else {
            let remainder = &bytes[size_mod32 & !3..];
            let packetL = SseHash::load_multiple_of_four(bytes);
            let last4 = unordered_load3(remainder);
            let packetH = V2x64U::from(_mm_cvtsi64_si128(last4 as i64));
            (packetH, packetL)
        }
    }

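    /// Folds the buffered length into `v0`, rotates the 32-bit lanes of `v1` by that
    /// length, then hashes the packed remainder as a final packet.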
    #[target_feature(enable = "sse4.1")]
    unsafe fn update_remainder(&mut self) {
        let size = self.buffer.len();
        let vsize_mod32 = _mm_set1_epi32(size as i32);
        self.v0L += V2x64U::from(vsize_mod32);
        self.v0H += V2x64U::from(vsize_mod32);
        self.rotate_32_by(size as i64);
        let packet = SseHash::remainder(self.buffer.as_slice());
        self.update(packet);
    }

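    /// Rotates each 32-bit lane of `v1` left by `count` bits, built from paired shifts
    /// since SSE 4.1 has no variable rotate instruction.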
    #[target_feature(enable = "sse4.1")]
    unsafe fn rotate_32_by(&mut self, count: i64) {
        let vL = &mut self.v1L;
        let vH = &mut self.v1H;
        let count_left = _mm_cvtsi64_si128(count);
        let count_right = _mm_cvtsi64_si128(32 - count);
        let shifted_leftL = V2x64U::from(_mm_sll_epi32(vL.0, count_left));
        let shifted_leftH = V2x64U::from(_mm_sll_epi32(vH.0, count_left));
        let shifted_rightL = V2x64U::from(_mm_srl_epi32(vL.0, count_right));
        let shifted_rightH = V2x64U::from(_mm_srl_epi32(vH.0, count_right));
        *vL = shifted_leftL | shifted_rightL;
        *vH = shifted_leftH | shifted_rightH;
    }

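    /// Loads one 32-byte packet as an unaligned (high, low) pair of 128-bit lanes.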
    #[inline]
    #[target_feature(enable = "sse4.1")]
    unsafe fn data_to_lanes(packet: &[u8]) -> (V2x64U, V2x64U) {
        let ptr = packet.as_ptr().cast::<__m128i>();
        let packetL = V2x64U::from(_mm_loadu_si128(ptr));
        let packetH = V2x64U::from(_mm_loadu_si128(ptr.add(1)));

        (packetH, packetL)
    }

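    /// Buffered ingestion: whole 32-byte packets are hashed directly, and any tail is
    /// stashed in `buffer` until enough data arrives to complete a packet.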
    #[target_feature(enable = "sse4.1")]
    unsafe fn append(&mut self, data: &[u8]) {
        if self.buffer.is_empty() {
            let mut chunks = data.chunks_exact(PACKET_SIZE);
            for chunk in chunks.by_ref() {
                self.update(Self::data_to_lanes(chunk));
            }
            self.buffer.set_to(chunks.remainder());
        } else if let Some(tail) = self.buffer.fill(data) {
            self.update(Self::data_to_lanes(self.buffer.inner()));
            let mut chunks = tail.chunks_exact(PACKET_SIZE);
            for chunk in chunks.by_ref() {
                self.update(Self::data_to_lanes(chunk));
            }

            self.buffer.set_to(chunks.remainder());
        }
    }
}

impl_write!(SseHash);
impl_hasher!(SseHash);

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg_attr(miri, ignore)]
    #[test]
    fn test_zipper_merge() {
        unsafe {
            let x = V2x64U::new(0x0264_432C_CD8A_70E0, 0x0B28_E3EF_EBB3_172D);
            let y = SseHash::zipper_merge(&x);
            assert_eq!(y.as_arr(), [0x2D02_1764_E3B3_2CEB, 0x0BE0_2870_438A_EFCD]);
        }
    }
}