diff --git a/src/ws/mask.rs b/src/ws/mask.rs
index 16f0f6b15..2e142d651 100644
--- a/src/ws/mask.rs
+++ b/src/ws/mask.rs
@@ -1,89 +1,107 @@
 //! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs)
 #![cfg_attr(feature = "cargo-clippy", allow(cast_ptr_alignment))]
-use std::cmp::min;
-use std::mem::uninitialized;
+use std::slice;
 use std::ptr::copy_nonoverlapping;
 
+// Holds a slice guaranteed to be shorter than 8 bytes
+struct ShortSlice<'a>(&'a mut [u8]);
+
+impl<'a> ShortSlice<'a> {
+    unsafe fn new(slice: &'a mut [u8]) -> Self {
+        // Sanity check for debug builds
+        debug_assert!(slice.len() < 8);
+        ShortSlice(slice)
+    }
+    fn len(&self) -> usize {
+        self.0.len()
+    }
+}
+
 /// Faster version of `apply_mask()` which operates on 8-byte blocks.
-///
-/// unsafe because uses pointer math and bit operations for performance
 #[inline]
 #[cfg_attr(feature = "cargo-clippy", allow(cast_lossless))]
 pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) {
-    unsafe {
-        let mut ptr = buf.as_mut_ptr();
-        let mut len = buf.len();
+    // Extend the mask to 64 bits
+    let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64);
+    // Split the buffer into three segments
+    let (head, mid, tail) = align_buf(buf);
 
-        // Possible first unaligned block.
-        let head = min(len, (8 - (ptr as usize & 0x7)) & 0x3);
-        let mask_u32 = if head > 0 {
-            let n = if head > 4 { head - 4 } else { head };
-
-            let mask_u32 = if n > 0 {
-                xor_mem(ptr, mask_u32, n);
-                ptr = ptr.offset(head as isize);
-                len -= n;
-                if cfg!(target_endian = "big") {
-                    mask_u32.rotate_left(8 * n as u32)
-                } else {
-                    mask_u32.rotate_right(8 * n as u32)
-                }
-            } else {
-                mask_u32
-            };
-
-            if head > 4 {
-                *(ptr as *mut u32) ^= mask_u32;
-                ptr = ptr.offset(4);
-                len -= 4;
-            }
-            mask_u32
+    // Initial unaligned segment
+    let head_len = head.len();
+    if head_len > 0 {
+        xor_short(head, mask_u64);
+        if cfg!(target_endian = "big") {
+            mask_u64 = mask_u64.rotate_left(8 * head_len as u32);
         } else {
-            mask_u32
-        };
-
-        if len > 0 {
-            debug_assert_eq!(ptr as usize % 4, 0);
-        }
-
-        // Properly aligned middle of the data.
-        if len >= 8 {
-            let mut mask_u64 = mask_u32 as u64;
-            mask_u64 = mask_u64 << 32 | mask_u32 as u64;
-
-            while len >= 8 {
-                *(ptr as *mut u64) ^= mask_u64;
-                ptr = ptr.offset(8);
-                len -= 8;
-            }
-        }
-
-        while len >= 4 {
-            *(ptr as *mut u32) ^= mask_u32;
-            ptr = ptr.offset(4);
-            len -= 4;
-        }
-
-        // Possible last block.
-        if len > 0 {
-            xor_mem(ptr, mask_u32, len);
+            mask_u64 = mask_u64.rotate_right(8 * head_len as u32);
         }
     }
+    // Aligned segment
+    for v in mid {
+        *v ^= mask_u64;
+    }
+    // Final unaligned segment
+    if tail.len() > 0 {
+        xor_short(tail, mask_u64);
+    }
 }
 
 #[inline]
 // TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so
-// inefficient, it could be done better. The compiler does not see that len is
-// limited to 3.
-unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
-    let mut b: u32 = uninitialized();
-    #[allow(trivial_casts)]
-    copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
-    b ^= mask;
-    #[allow(trivial_casts)]
-    copy_nonoverlapping(&b as *const _ as *const u8, ptr, len);
+// inefficient, it could be done better. The compiler does not understand that
+// a `ShortSlice` must be smaller than a u64.
+fn xor_short(buf: ShortSlice, mask: u64) {
+    // Unsafe: we know that a `ShortSlice` fits in a u64
+    unsafe {
+        let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len());
+        let mut b: u64 = 0;
+        #[allow(trivial_casts)]
+        copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
+        b ^= mask;
+        #[allow(trivial_casts)]
+        copy_nonoverlapping(&b as *const _ as *const u8, ptr, len);
+    }
 }
 
+#[inline]
+// Unsafe: caller must ensure the buffer has the correct size and alignment
+unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] {
+    // Assert correct size and alignment in debug builds
+    debug_assert!(buf.len() & 0x7 == 0);
+    debug_assert!(buf.as_ptr() as usize & 0x7 == 0);
+
+    slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3)
+}
+
+#[inline]
+// Splits a slice into three parts: an unaligned short head and tail, plus an aligned
+// u64 mid section.
+fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) {
+    let start_ptr = buf.as_ptr() as usize;
+    let end_ptr = start_ptr + buf.len();
+
+    // Round *up* to next aligned boundary for start
+    let start_aligned = (start_ptr + 7) & !0x7;
+    // Round *down* to last aligned boundary for end
+    let end_aligned = end_ptr & !0x7;
+
+    if end_aligned >= start_aligned {
+        // We have our three segments (head, mid, tail)
+        let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr);
+        let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr);
+
+        // Unsafe: we know the middle section is correctly aligned, and the outer
+        // sections are smaller than 8 bytes
+        unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice(tail)) }
+    } else {
+        // We didn't cross even one aligned boundary!
+
+        // Unsafe: The outer sections are smaller than 8 bytes
+        unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) }
+    }
+}
+
+
 #[cfg(test)]
 mod tests {
     use super::apply_mask;
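For reference, here is a minimal sketch of how the reworked `apply_mask()` could be exercised, assuming it sits next to the code above (for example inside the existing `mod tests`). The `naive_mask` helper and the test name are made up for illustration and are not part of the patch; the sketch also assumes the mask bytes cycle in native byte order, which is what the `mask_u64` layout above implies. It compares the head/mid/tail path against a plain byte-by-byte XOR over many offsets and lengths, so every alignment handled by `align_buf()` gets hit.

```rust
// Hypothetical sanity check, not part of this patch.
// Reference implementation: XOR each byte with the mask bytes, cycling
// through the mask in native byte order (the same layout `mask_u64` uses).
fn naive_mask(buf: &mut [u8], mask: u32) {
    let mask_bytes = mask.to_ne_bytes();
    for (i, b) in buf.iter_mut().enumerate() {
        *b ^= mask_bytes[i % 4];
    }
}

#[test]
fn agrees_with_naive_mask() {
    let mask = 0xd3fa_5c34u32;
    let data: Vec<u8> = (0u8..64).collect();

    // Vary offset and length so the head/mid/tail split in align_buf()
    // is exercised at every alignment, including empty and sub-word slices.
    for offset in 0..8 {
        for len in 0..=(data.len() - offset) {
            let mut fast = data.clone();
            let mut slow = data.clone();
            apply_mask(&mut fast[offset..offset + len], mask);
            naive_mask(&mut slow[offset..offset + len], mask);
            assert_eq!(fast, slow);
        }
    }
}
```

Masking sub-slices of one allocation in place (rather than copying each case into a fresh `Vec`) is what actually varies the pointer alignment, which is the behaviour the unaligned head and tail paths depend on.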