core/slice/sort/stable/quicksort.rs
//! This module contains a stable quicksort and partition implementation.

use crate::mem::{ManuallyDrop, MaybeUninit};
use crate::slice::sort::shared::FreezeMarker;
use crate::slice::sort::shared::pivot::choose_pivot;
use crate::slice::sort::shared::smallsort::StableSmallSortTypeImpl;
use crate::{intrinsics, ptr};

/// Sorts `v` recursively using quicksort.
///
/// `scratch.len()` must be at least `max(v.len() - v.len() / 2, SMALL_SORT_GENERAL_SCRATCH_LEN)`,
/// otherwise the implementation may abort.
///
/// When `limit` is initialized to `c * log(v.len())` for some constant `c`, it
/// ensures we do not overflow the stack or go quadratic.
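///
/// A call sketch (illustrative only; the `c = 2` constant and the scratch
/// sizing shown here are assumptions, not values prescribed by this module):
///
/// ```ignore
/// let len = v.len(); // assume len >= 2
/// let mut scratch: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
/// let limit = 2 * (len as u32).ilog2(); // c = 2
/// quicksort(v, scratch.spare_capacity_mut(), limit, None, &mut |a, b| a < b);
/// ```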
#[inline(never)]
pub fn quicksort<T, F: FnMut(&T, &T) -> bool>(
    mut v: &mut [T],
    scratch: &mut [MaybeUninit<T>],
    mut limit: u32,
    mut left_ancestor_pivot: Option<&T>,
    is_less: &mut F,
) {
    loop {
        let len = v.len();

        if len <= T::small_sort_threshold() {
            T::small_sort(v, scratch, is_less);
            return;
        }

        if limit == 0 {
            // We have had too many bad pivots, switch to O(n log n) fallback
            // algorithm. In our case that is driftsort in eager mode.
            crate::slice::sort::stable::drift::sort(v, scratch, true, is_less);
            return;
        }
        limit -= 1;

        let pivot_pos = choose_pivot(v, is_less);

        // SAFETY: We only access the temporary copy for Freeze types, otherwise
        // self-modifications via `is_less` would not be observed and this would
        // be unsound. Our temporary copy does not escape this scope.
        let pivot_copy = unsafe { ManuallyDrop::new(ptr::read(&v[pivot_pos])) };
        let pivot_ref = (!has_direct_interior_mutability::<T>()).then_some(&*pivot_copy);
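
        // E.g. with an element type like Cell<u64>, `is_less` could mutate the
        // element after this copy was taken, making the copy stale. For such
        // types pivot_ref is None, and pivot re-use is instead detected via
        // left_partition_len == 0, as described below.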

        // We choose a pivot, and check if this pivot is equal to our left
        // ancestor. If true, we do a partition putting equal elements on the
        // left and do not recurse on it. This gives O(n log k) sorting for k
        // distinct values, a strategy borrowed from pdqsort. For types with
        // interior mutability we can't soundly create a temporary copy of the
        // ancestor pivot, and use left_partition_len == 0 as our method for
        // detecting when we re-use a pivot, which means we do at most three
        // partition operations with pivot p instead of the optimal two.
        let mut perform_equal_partition = false;
        if let Some(la_pivot) = left_ancestor_pivot {
            perform_equal_partition = !is_less(la_pivot, &v[pivot_pos]);
        }

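        // For example: if every element of `v` equals the pivot, the less-than
        // partition below moves nothing to the left (left_partition_len == 0),
        // which triggers the equal partition and consumes the whole slice in a
        // single pass instead of recursing further.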
        let mut left_partition_len = 0;
        if !perform_equal_partition {
            left_partition_len = stable_partition(v, scratch, pivot_pos, false, is_less);
            perform_equal_partition = left_partition_len == 0;
        }

        if perform_equal_partition {
            let mid_eq = stable_partition(v, scratch, pivot_pos, true, &mut |a, b| !is_less(b, a));
            v = &mut v[mid_eq..];
            left_ancestor_pivot = None;
            continue;
        }

        // Process left side with the next loop iter, right side with recursion.
        let (left, right) = v.split_at_mut(left_partition_len);
        quicksort(right, scratch, limit, pivot_ref, is_less);
        v = left;
    }
}

/// Partitions `v` using pivot `p = v[pivot_pos]` and returns the number of
/// elements less than `p`. The relative order of elements that compare < p and
/// those that compare >= p is preserved - it is a stable partition.
///
/// If `is_less` is not a strict total order or panics, `scratch.len() < v.len()`,
/// or `pivot_pos >= v.len()`, the result and `v`'s state are sound but unspecified.
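///
/// Worked example (illustrative): for `v = [5, 9, 1, 7, 3]`, `pivot_pos = 0`
/// (`p = 5`) and `pivot_goes_left = false`, the `< p` elements `[1, 3]` keep
/// their order on the left and the `>= p` elements `[5, 9, 7]` keep theirs on
/// the right, so `v` becomes `[1, 3, 5, 9, 7]` and `2` is returned.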
fn stable_partition<T, F: FnMut(&T, &T) -> bool>(
    v: &mut [T],
    scratch: &mut [MaybeUninit<T>],
    pivot_pos: usize,
    pivot_goes_left: bool,
    is_less: &mut F,
) -> usize {
    let len = v.len();

    if intrinsics::unlikely(scratch.len() < len || pivot_pos >= len) {
        core::intrinsics::abort()
    }

    let v_base = v.as_ptr();
    let scratch_base = MaybeUninit::slice_as_mut_ptr(scratch);

    // The core idea is to write the values that compare as less-than to the left
    // side of `scratch`, while the values that compare as greater than or equal
    // to `v[pivot_pos]` go to the right side of `scratch` in reverse. See
    // PartitionState for details.
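    //
    // For example, with v = [4, 1, 3, 2], p = 3 and pivot_goes_left == false,
    // scratch ends up as [1, 2, 3, 4]: 1 and 2 grow from the left in scan
    // order, while 4 and then 3 are written from the right end inwards. The
    // reversed copy-back at the end restores the right side's scan order,
    // giving the stable result v = [1, 2, 4, 3] with num_left == 2.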

    // SAFETY: see individual comments.
    unsafe {
        // SAFETY: we made sure the scratch has length >= len and that pivot_pos
        // is in-bounds. v and scratch are disjoint slices.
        let pivot = v_base.add(pivot_pos);
        let mut state = PartitionState::new(v_base, scratch_base, len);

        let mut pivot_in_scratch = ptr::null_mut();
        let mut loop_end_pos = pivot_pos;

        // SAFETY: this loop is equivalent to calling state.partition_one
        // exactly len times.
        loop {
            // Ideally the outer loop won't be unrolled, to save binary size,
            // but we do want the inner loop to be unrolled for small types, as
            // this gave significant performance boosts in benchmarks. Unrolling
            // via `for _ in 0..UNROLL_LEN { .. }` instead of manually improves
            // compile times but has a ~10-20% performance penalty on opt-level=s.
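            // The four identical calls below are the manual 4x unroll matching
            // UNROLL_LEN; each `partition_one` call advances `state.scan` by one.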
            if const { size_of::<T>() <= 16 } {
                const UNROLL_LEN: usize = 4;
                let unroll_end = v_base.add(loop_end_pos.saturating_sub(UNROLL_LEN - 1));
                while state.scan < unroll_end {
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                    state.partition_one(is_less(&*state.scan, &*pivot));
                }
            }

            let loop_end = v_base.add(loop_end_pos);
            while state.scan < loop_end {
                state.partition_one(is_less(&*state.scan, &*pivot));
            }

            if loop_end_pos == len {
                break;
            }

            // We avoid comparing pivot with itself, as this could create deadlocks for
            // certain comparison operators. We also store its location for later.
            pivot_in_scratch = state.partition_one(pivot_goes_left);

            loop_end_pos = len;
        }

        // `pivot` must be copied into its correct position again, because a
        // comparison operator might have modified it.
        if has_direct_interior_mutability::<T>() {
            ptr::copy_nonoverlapping(pivot, pivot_in_scratch, 1);
        }

        // SAFETY: partition_one being called exactly len times guarantees that scratch
        // is initialized with a permuted copy of `v`, and that num_left <= v.len().
        // Copying scratch[0..num_left] and scratch[num_left..v.len()] back is thus
        // sound, as the values in scratch will never be read again, meaning our copies
        // semantically act as moves, permuting `v`.

        // Copy all the elements < p directly from scratch to v.
        let v_base = v.as_mut_ptr();
        ptr::copy_nonoverlapping(scratch_base, v_base, state.num_left);

        // Copy the elements >= p in reverse order.
        for i in 0..len - state.num_left {
            ptr::copy_nonoverlapping(
                scratch_base.add(len - 1 - i),
                v_base.add(state.num_left + i),
                1,
            );
        }

        state.num_left
    }
}

struct PartitionState<T> {
    // The start of the scratch auxiliary memory.
    scratch_base: *mut T,
    // The current element that is being looked at, scans left to right through the slice.
    scan: *const T,
    // Counts the number of elements that went to the left side, also works around:
    // https://github.com/rust-lang/rust/issues/117128
    num_left: usize,
    // Reverse scratch output pointer.
    scratch_rev: *mut T,
}

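// Invariant maintained by `partition_one`: after `i` calls, with `num_right ==
// i - num_left`, `scratch[..num_left]` holds the elements sent left (in scan
// order) and `scratch[len - num_right..]` holds the elements sent right (in
// reverse scan order). Between calls, `scratch_rev + num_left == scratch_base + len - num_right`.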
impl<T> PartitionState<T> {
    /// # Safety
    ///
    /// `scan` and `scratch` must point to valid disjoint buffers of length `len`. The
    /// scan buffer must be initialized.
    unsafe fn new(scan: *const T, scratch: *mut T, len: usize) -> Self {
        // SAFETY: See function safety comment.
        unsafe { Self { scratch_base: scratch, scan, num_left: 0, scratch_rev: scratch.add(len) } }
    }

    /// Depending on the value of `towards_left` this function will write a value
    /// to the growing left or right side of the scratch memory. This forms the
    /// branchless core of the partition.
    ///
    /// # Safety
    ///
    /// This function may be called at most `len` times. If it is called exactly
    /// `len` times, the scratch buffer then contains a copy of each element from
    /// the scan buffer exactly once - a permutation, and `num_left <= len`.
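    ///
    /// For example, on a length-2 scan buffer `[a, b]`, calling
    /// `partition_one(true)` then `partition_one(false)` writes `a` to
    /// `scratch[0]` and `b` to `scratch[1]`, leaving `num_left == 1`.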
    unsafe fn partition_one(&mut self, towards_left: bool) -> *mut T {
        // SAFETY: see individual comments.
        unsafe {
            // SAFETY: in-bounds because this function is called at most len times, and
            // thus scratch_rev has been decremented at most len - 1 times before this
            // call. Similarly, num_left < len and num_right < len, where num_right ==
            // i - num_left at the start of the ith iteration (zero-indexed).
            self.scratch_rev = self.scratch_rev.sub(1);

            // SAFETY: now we have scratch_rev == base + len - (i + 1). This means
            // scratch_rev + num_left == base + len - 1 - num_right < base + len.
            let dst_base = if towards_left { self.scratch_base } else { self.scratch_rev };
            let dst = dst_base.add(self.num_left);
            ptr::copy_nonoverlapping(self.scan, dst, 1);

            self.num_left += towards_left as usize;
            self.scan = self.scan.add(1);
            dst
        }
    }
}

trait IsFreeze {
    fn is_freeze() -> bool;
}

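// Blanket impl with a `default fn` that the `FreezeMarker` impl below
// specializes; this relies on the `min_specialization` feature used in core.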
impl<T> IsFreeze for T {
    default fn is_freeze() -> bool {
        false
    }
}
impl<T: FreezeMarker> IsFreeze for T {
    fn is_freeze() -> bool {
        true
    }
}

#[must_use]
fn has_direct_interior_mutability<T>() -> bool {
    // If a type has interior mutability it may alter itself during comparison
    // in a way that must be preserved after the sort operation concludes.
    // Otherwise a type like Mutex<Option<Box<str>>> could lead to double free.
    !T::is_freeze()
}