core/slice/
rotate.rs

1use crate::mem::{MaybeUninit, SizedTypeProperties};
2use crate::ptr;
3
/// Scratch storage for `ptr_rotate_memmove`: 32 `usize`s (256 bytes on 64-bit targets).
/// `ptr_rotate` only selects that algorithm when the shorter side of the rotation fits in a
/// value of this type.
type BufType = [usize; 32];
5
6/// Rotates the range `[mid-left, mid+right)` such that the element at `mid` becomes the first
7/// element. Equivalently, rotates the range `left` elements to the left or `right` elements to the
8/// right.
9///
10/// # Safety
11///
12/// The specified range must be valid for reading and writing.
13#[inline]
14pub(super) const unsafe fn ptr_rotate<T>(left: usize, mid: *mut T, right: usize) {
15    if T::IS_ZST {
16        return;
17    }
18    // abort early if the rotate is a no-op
19    if (left == 0) || (right == 0) {
20        return;
21    }
22    // `T` is not a zero-sized type, so it's okay to divide by its size.
23    if !cfg!(feature = "optimize_for_size")
24        // FIXME(const-hack): Use cmp::min when available in const
25        && const_min(left, right) <= size_of::<BufType>() / size_of::<T>()
26    {
27        // SAFETY: guaranteed by the caller
28        unsafe { ptr_rotate_memmove(left, mid, right) };
29    } else if !cfg!(feature = "optimize_for_size")
30        && ((left + right < 24) || (size_of::<T>() > size_of::<[usize; 4]>()))
31    {
32        // SAFETY: guaranteed by the caller
33        unsafe { ptr_rotate_gcd(left, mid, right) }
34    } else {
35        // SAFETY: guaranteed by the caller
36        unsafe { ptr_rotate_swap(left, mid, right) }
37    }
38}
39
40/// Algorithm 1 is used if `min(left, right)` is small enough to fit onto a stack buffer. The
41/// `min(left, right)` elements are copied onto the buffer, `memmove` is applied to the others, and
42/// the ones on the buffer are moved back into the hole on the opposite side of where they
43/// originated.
44///
45/// # Safety
46///
47/// The specified range must be valid for reading and writing.
48#[inline]
49const unsafe fn ptr_rotate_memmove<T>(left: usize, mid: *mut T, right: usize) {
50    // The `[T; 0]` here is to ensure this is appropriately aligned for T
51    let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
52    let buf = rawarray.as_mut_ptr() as *mut T;
53    // SAFETY: `mid-left <= mid-left+right < mid+right`
54    let dim = unsafe { mid.sub(left).add(right) };
55    if left <= right {
56        // SAFETY:
57        //
58        // 1) The `if` condition about the sizes ensures `[mid-left; left]` will fit in
59        //    `buf` without overflow and `buf` was created just above and so cannot be
60        //    overlapped with any value of `[mid-left; left]`
61        // 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
62        //    about overlaps here.
63        // 3) The `if` condition about `left <= right` ensures writing `left` elements to
64        //    `dim = mid-left+right` is valid because:
65        //    - `buf` is valid and `left` elements were written in it in 1)
66        //    - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
67        unsafe {
68            // 1)
69            ptr::copy_nonoverlapping(mid.sub(left), buf, left);
70            // 2)
71            ptr::copy(mid, mid.sub(left), right);
72            // 3)
73            ptr::copy_nonoverlapping(buf, dim, left);
74        }
75    } else {
76        // SAFETY: same reasoning as above but with `left` and `right` reversed
77        unsafe {
78            ptr::copy_nonoverlapping(mid, buf, right);
79            ptr::copy(mid.sub(left), dim, left);
80            ptr::copy_nonoverlapping(buf, mid.sub(left), right);
81        }
82    }
83}
84
/// Algorithm 2 is used for small values of `left + right` or for large `T`. The elements
/// are moved into their final positions one at a time starting at `mid - left` and advancing by
/// `right` steps modulo `left + right`, such that only one temporary is needed. Eventually, we
/// arrive back at `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps
/// skip over elements. For example:
/// ```text
/// left = 10, right = 6
/// the `^` indicates an element in its final place
/// 6 7 8 9 10 11 12 13 14 15 . 0 1 2 3 4 5
/// after using one step of the above algorithm (The X will be overwritten at the end of the round,
/// and 12 is stored in a temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 2 3 4 5
///               ^
/// after using another step (now 2 is in the temporary):
/// X 7 8 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///               ^                 ^
/// after the third step (the steps wrap around, and 8 is in the temporary):
/// X 7 2 9 10 11 6 13 14 15 . 0 1 12 3 4 5
///     ^         ^                 ^
/// after 5 more steps, the round ends with the temporary 0 getting put in the X:
/// 0 7 2 9 4 11 6 13 8 15 . 10 1 12 3 14 5
/// ^   ^   ^    ^    ^       ^    ^    ^
/// ```
/// Fortunately, the number of skipped over elements between finalized elements is always equal, so
/// we can just offset our starting position and do more rounds (the total number of rounds is the
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
/// only once.
///
/// Algorithm 2 can be vectorized by chunking and performing many rounds at once, but there are too
/// few rounds on average until `left + right` is enormous, and the worst case of a single
/// round is always there.
///
/// # Safety
///
/// The specified range must be valid for reading and writing.
///
/// Note: if exactly one of `left` and `right` is zero, `i` never returns to its starting value
/// and the first loop does not terminate. `ptr_rotate` filters out zero-length sides before
/// dispatching here.
#[inline]
const unsafe fn ptr_rotate_gcd<T>(left: usize, mid: *mut T, right: usize) {
    // Algorithm 2
    // Microbenchmarks indicate that the average performance for random shifts is better all
    // the way until about `left + right == 32`, but the worst case performance breaks even
    // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
    // `usize`s, this algorithm also outperforms other algorithms.
    // SAFETY: callers must ensure `mid - left` is valid for reading and writing.
    let x = unsafe { mid.sub(left) };
    // beginning of first round
    // SAFETY: see previous comment.
    let mut tmp: T = unsafe { x.read() };
    let mut i = right;
    // `gcd` can be found beforehand by calculating `gcd(left + right, right)`,
    // but it is faster to do one loop which calculates the gcd as a side effect (as the
    // smallest nonzero value `i` takes during the first round), then do the rest of the chunk
    let mut gcd = right;
    // benchmarks reveal that it is faster to swap temporaries all the way through instead
    // of reading one temporary once, copying backwards, and then writing that temporary at
    // the very end. This is possibly due to the fact that swapping or replacing temporaries
    // uses only one memory address in the loop instead of needing to manage two.
    loop {
        // [long-safety-expl]
        // SAFETY: callers must ensure `[mid-left, mid+right)`, i.e. `[x, x+left+right)`, is
        // valid for reading and writing.
        //
        // - `i` starts at `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
        //   (using `left > 0`).
        // - `i <= left+right-1` is always true:
        //   - if `i < left`, `right` is added so `i < left+right`
        //   - if `i >= left`, `left` is subtracted so `i` only gets smaller.
        // - overflows cannot happen for `i` since it stays below `left+right`, and the
        //   function's safety contract requires all `left+right` elements of
        //   `[x, x+left+right)` to be valid, so that count itself cannot overflow.
        // - underflows cannot happen because `i` must be bigger or equal to `left` for
        //   a subtraction of `left` to happen.
        //
        // So `x+i` is valid for reading and writing if the caller respected the contract
        tmp = unsafe { x.add(i).replace(tmp) };
        // instead of incrementing `i` and then checking if it is outside the bounds, we
        // check if `i` will go outside the bounds on the next increment. This prevents
        // any wrapping of pointers or `usize`.
        if i >= left {
            i -= left;
            if i == 0 {
                // end of first round
                // SAFETY: tmp has been read from a valid source and x is valid for writing
                // according to the caller.
                unsafe { x.write(tmp) };
                break;
            }
            // this conditional must be here if `left + right >= 15`
            if i < gcd {
                gcd = i;
            }
        } else {
            i += right;
        }
    }
    // finish the chunk with more rounds
    // FIXME(const-hack): Use `for start in 1..gcd` when available in const
    let mut start = 1;
    while start < gcd {
        // SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
        // reading and writing as per the function's safety contract, see [long-safety-expl]
        // above
        tmp = unsafe { x.add(start).read() };
        // [safety-expl-addition]
        //
        // `gcd` divides both `right` and `left+right`, hence it also divides `left`, so
        // `gcd <= left`. With `start < gcd` this gives `i = start+right < left+right`, so
        // `x+i = mid-left+i` is valid for reading and writing according to the function's
        // safety contract. Inside the loop `i` stays below `left+right` for the same reasons
        // as in [long-safety-expl].
        i = start + right;
        loop {
            // SAFETY: see [long-safety-expl] and [safety-expl-addition]
            tmp = unsafe { x.add(i).replace(tmp) };
            if i >= left {
                i -= left;
                if i == start {
                    // SAFETY: see [long-safety-expl] and [safety-expl-addition]
                    unsafe { x.add(start).write(tmp) };
                    break;
                }
            } else {
                i += right;
            }
        }

        start += 1;
    }
}
211
212/// Algorithm 3 utilizes repeated swapping of `min(left, right)` elements.
213///
214/// ///
215/// ```text
216/// left = 11, right = 4
217/// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
218///                  ^  ^  ^  ^   ^ ^ ^ ^ swapping the right most elements with elements to the left
219/// [4 5 6 7 8 9 10 . 0 1 2 3] 11 12 13 14
220///        ^ ^ ^  ^   ^ ^ ^ ^ swapping these
221/// [4 5 6 . 0 1 2 3] 7 8 9 10 11 12 13 14
222/// we cannot swap any more, but a smaller rotation problem is left to solve
223/// ```
224/// when `left < right` the swapping happens from the left instead.
225///
226/// # Safety
227///
228/// The specified range must be valid for reading and writing.
229#[inline]
230const unsafe fn ptr_rotate_swap<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
231    loop {
232        if left >= right {
233            // Algorithm 3
234            // There is an alternate way of swapping that involves finding where the last swap
235            // of this algorithm would be, and swapping using that last chunk instead of swapping
236            // adjacent chunks like this algorithm is doing, but this way is still faster.
237            loop {
238                // SAFETY:
239                // `left >= right` so `[mid-right, mid+right)` is valid for reading and writing
240                // Subtracting `right` from `mid` each turn is counterbalanced by the addition and
241                // check after it.
242                unsafe {
243                    ptr::swap_nonoverlapping(mid.sub(right), mid, right);
244                    mid = mid.sub(right);
245                }
246                left -= right;
247                if left < right {
248                    break;
249                }
250            }
251        } else {
252            // Algorithm 3, `left < right`
253            loop {
254                // SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
255                // `left < right` so `mid+left < mid+right`.
256                // Adding `left` to `mid` each turn is counterbalanced by the subtraction and check
257                // after it.
258                unsafe {
259                    ptr::swap_nonoverlapping(mid.sub(left), mid, left);
260                    mid = mid.add(left);
261                }
262                right -= left;
263                if right < left {
264                    break;
265                }
266            }
267        }
268        if (right == 0) || (left == 0) {
269            return;
270        }
271    }
272}
273
/// Returns the smaller of `left` and `right`; equal inputs yield that shared value.
// FIXME(const-hack): Use cmp::min when available in const
const fn const_min(left: usize, right: usize) -> usize {
    if left <= right { left } else { right }
}