mirror of
https://github.com/actix/actix-web.git
synced 2024-12-30 12:00:38 +00:00
Refactor apply_mask
implementation, removing dead code paths and reducing scope of unsafety
This commit is contained in:
parent
85012f947a
commit
87824a9cf6
1 changed files with 86 additions and 68 deletions
138
src/ws/mask.rs
138
src/ws/mask.rs
|
@ -1,89 +1,107 @@
|
|||
//! This is code from [Tungstenite project](https://github.com/snapview/tungstenite-rs)
|
||||
#![cfg_attr(feature = "cargo-clippy", allow(cast_ptr_alignment))]
|
||||
use std::cmp::min;
|
||||
use std::mem::uninitialized;
|
||||
use std::slice;
|
||||
use std::ptr::copy_nonoverlapping;
|
||||
|
||||
// Holds a slice guaranteed to be shorter than 8 bytes
|
||||
struct ShortSlice<'a>(&'a mut [u8]);
|
||||
|
||||
impl<'a> ShortSlice<'a> {
|
||||
unsafe fn new(slice: &'a mut [u8]) -> Self {
|
||||
// Sanity check for debug builds
|
||||
debug_assert!(slice.len() < 8);
|
||||
ShortSlice(slice)
|
||||
}
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Faster version of `apply_mask()` which operates on 8-byte blocks.
|
||||
///
|
||||
/// unsafe because uses pointer math and bit operations for performance
|
||||
#[inline]
|
||||
#[cfg_attr(feature = "cargo-clippy", allow(cast_lossless))]
|
||||
pub(crate) fn apply_mask(buf: &mut [u8], mask_u32: u32) {
|
||||
unsafe {
|
||||
let mut ptr = buf.as_mut_ptr();
|
||||
let mut len = buf.len();
|
||||
// Extend the mask to 64 bits
|
||||
let mut mask_u64 = ((mask_u32 as u64) << 32) | (mask_u32 as u64);
|
||||
// Split the buffer into three segments
|
||||
let (head, mid, tail) = align_buf(buf);
|
||||
|
||||
// Possible first unaligned block.
|
||||
let head = min(len, (8 - (ptr as usize & 0x7)) & 0x3);
|
||||
let mask_u32 = if head > 0 {
|
||||
let n = if head > 4 { head - 4 } else { head };
|
||||
|
||||
let mask_u32 = if n > 0 {
|
||||
xor_mem(ptr, mask_u32, n);
|
||||
ptr = ptr.offset(head as isize);
|
||||
len -= n;
|
||||
// Initial unaligned segment
|
||||
let head_len = head.len();
|
||||
if head_len > 0 {
|
||||
xor_short(head, mask_u64);
|
||||
if cfg!(target_endian = "big") {
|
||||
mask_u32.rotate_left(8 * n as u32)
|
||||
mask_u64 = mask_u64.rotate_left(8 * head_len as u32);
|
||||
} else {
|
||||
mask_u32.rotate_right(8 * n as u32)
|
||||
}
|
||||
} else {
|
||||
mask_u32
|
||||
};
|
||||
|
||||
if head > 4 {
|
||||
*(ptr as *mut u32) ^= mask_u32;
|
||||
ptr = ptr.offset(4);
|
||||
len -= 4;
|
||||
}
|
||||
mask_u32
|
||||
} else {
|
||||
mask_u32
|
||||
};
|
||||
|
||||
if len > 0 {
|
||||
debug_assert_eq!(ptr as usize % 4, 0);
|
||||
}
|
||||
|
||||
// Properly aligned middle of the data.
|
||||
if len >= 8 {
|
||||
let mut mask_u64 = mask_u32 as u64;
|
||||
mask_u64 = mask_u64 << 32 | mask_u32 as u64;
|
||||
|
||||
while len >= 8 {
|
||||
*(ptr as *mut u64) ^= mask_u64;
|
||||
ptr = ptr.offset(8);
|
||||
len -= 8;
|
||||
mask_u64 = mask_u64.rotate_right(8 * head_len as u32);
|
||||
}
|
||||
}
|
||||
|
||||
while len >= 4 {
|
||||
*(ptr as *mut u32) ^= mask_u32;
|
||||
ptr = ptr.offset(4);
|
||||
len -= 4;
|
||||
}
|
||||
|
||||
// Possible last block.
|
||||
if len > 0 {
|
||||
xor_mem(ptr, mask_u32, len);
|
||||
// Aligned segment
|
||||
for v in mid {
|
||||
*v ^= mask_u64;
|
||||
}
|
||||
// Final unaligned segment
|
||||
if tail.len() > 0 {
|
||||
xor_short(tail, mask_u64);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// TODO: copy_nonoverlapping here compiles to call memcpy. While it is not so
|
||||
// inefficient, it could be done better. The compiler does not see that len is
|
||||
// limited to 3.
|
||||
unsafe fn xor_mem(ptr: *mut u8, mask: u32, len: usize) {
|
||||
let mut b: u32 = uninitialized();
|
||||
// inefficient, it could be done better. The compiler does not understand that
|
||||
// a `ShortSlice` must be smaller than a u64.
|
||||
fn xor_short(buf: ShortSlice, mask: u64) {
|
||||
// Unsafe: we know that a `ShortSlice` fits in a u64
|
||||
unsafe {
|
||||
let (ptr, len) = (buf.0.as_mut_ptr(), buf.0.len());
|
||||
let mut b: u64 = 0;
|
||||
#[allow(trivial_casts)]
|
||||
copy_nonoverlapping(ptr, &mut b as *mut _ as *mut u8, len);
|
||||
b ^= mask;
|
||||
#[allow(trivial_casts)]
|
||||
copy_nonoverlapping(&b as *const _ as *const u8, ptr, len);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// Unsafe: caller must ensure the buffer has the correct size and alignment
|
||||
unsafe fn cast_slice(buf: &mut [u8]) -> &mut [u64] {
|
||||
// Assert correct size and alignment in debug builds
|
||||
debug_assert!(buf.len() & 0x7 == 0);
|
||||
debug_assert!(buf.as_ptr() as usize & 0x7 == 0);
|
||||
|
||||
slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u64, buf.len() >> 3)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
// Splits a slice into three parts: an unaligned short head and tail, plus an aligned
|
||||
// u64 mid section.
|
||||
fn align_buf(buf: &mut [u8]) -> (ShortSlice, &mut [u64], ShortSlice) {
|
||||
let start_ptr = buf.as_ptr() as usize;
|
||||
let end_ptr = start_ptr + buf.len();
|
||||
|
||||
// Round *up* to next aligned boundary for start
|
||||
let start_aligned = (start_ptr+7) & !0x7;
|
||||
// Round *down* to last aligned boundary for end
|
||||
let end_aligned = end_ptr & !0x7;
|
||||
|
||||
if end_aligned >= start_aligned {
|
||||
// We have our three segments (head, mid, tail)
|
||||
let (tmp, tail) = buf.split_at_mut(end_aligned - start_ptr);
|
||||
let (head, mid) = tmp.split_at_mut(start_aligned - start_ptr);
|
||||
|
||||
// Unsafe: we know the middle section is correctly aligned, and the outer
|
||||
// sections are smaller than 8 bytes
|
||||
unsafe { (ShortSlice::new(head), cast_slice(mid), ShortSlice(tail)) }
|
||||
} else {
|
||||
// We didn't cross even one aligned boundary!
|
||||
|
||||
// Unsafe: The outer sections are smaller than 8 bytes
|
||||
unsafe { (ShortSlice::new(buf), &mut [], ShortSlice::new(&mut [])) }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::apply_mask;
|
||||
|
|
Loading…
Reference in a new issue