I’ve recently started exploring the rustc codebase and noticed the following function:
/// Uses a sorted slice `data: &[E]` as a kind of "multi-map". The
/// `key_fn` extracts a key of type `K` from the data, and this
/// function finds the range of elements that match the key. `data`
/// must have been sorted as if by a call to `sort_by_key` for this to
/// work.
pub fn binary_search_slice<'d, E, K>(data: &'d [E], key_fn: impl Fn(&E) -> K, key: &K) -> &'d [E]
where
    K: Ord,
{
    // Locate *any* element with a matching key; if there is none, the
    // answer is the empty slice.
    let Ok(mid) = data.binary_search_by_key(key, &key_fn) else {
        return &[];
    };
    let size = data.len();
    // We get back *some* element with the given key -- so do
    // a galloping search backwards to find the *first* one.
    // `start` jumps left in exponentially growing steps until it exits
    // the run of equal keys (or hits 0); `previous` trails behind as
    // the last probed index known to still be inside the run.
    let mut start = mid;
    let mut previous = mid;
    let mut step = 1;
    loop {
        start = start.saturating_sub(step);
        if start == 0 || key_fn(&data[start]) != *key {
            break;
        }
        previous = start;
        step *= 2;
    }
    // Binary search within the bracket [start, previous] for the left
    // edge of the run of matching keys.
    step = previous - start;
    while step > 1 {
        let half = step / 2;
        let mid = start + half;
        if key_fn(&data[mid]) != *key {
            start = mid;
        }
        step -= half;
    }
    // adjust by one if we have overshot
    if start < size && key_fn(&data[start]) != *key {
        start += 1;
    }
    // Now search forward to find the *last* one.
    // Mirror image of the backward gallop: `end` jumps right past the
    // run (clamped to `size`); `previous` trails inside the run.
    let mut end = mid;
    let mut previous = mid;
    let mut step = 1;
    loop {
        end = end.saturating_add(step).min(size);
        if end == size || key_fn(&data[end]) != *key {
            break;
        }
        previous = end;
        step *= 2;
    }
    // Binary search within (previous, end] for the first index past the
    // run; `end` becomes the exclusive upper bound of the result.
    step = end - previous;
    while step > 1 {
        let half = step / 2;
        let mid = end - half;
        if key_fn(&data[mid]) != *key {
            end = mid;
        }
        step -= half;
    }
    &data[start..end]
}
An attentive reader would notice that this function is essentially std::equal_range from C++, but with a much more involved implementation. Even with the usual standard-library name uglification, LLVM’s libc++ implementation is much easier to understand and reason about:
// libc++'s std::equal_range: narrow the range with an ordinary binary
// search until *__mid compares equivalent to __value, then delegate to
// __lower_bound / __upper_bound on the two remaining sub-ranges to pin
// down the exact boundaries of the equal run.
template <class _AlgPolicy, class _Compare, class _Iter, class _Sent, class _Tp, class _Proj>
_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Iter, _Iter>
__equal_range(_Iter __first, _Sent __last, const _Tp& __value, _Compare&& __comp, _Proj&& __proj) {
  auto __len = _IterOps<_AlgPolicy>::distance(__first, __last);
  _Iter __end = _IterOps<_AlgPolicy>::next(__first, __last);
  while (__len != 0) {
    auto __half_len = std::__half_positive(__len);
    _Iter __mid = _IterOps<_AlgPolicy>::next(__first, __half_len);
    if (std::__invoke(__comp, std::__invoke(__proj, *__mid), __value)) {
      // *__mid < __value: the whole equal run lies to the right of __mid.
      __first = ++__mid;
      __len -= __half_len + 1;
    } else if (std::__invoke(__comp, __value, std::__invoke(__proj, *__mid))) {
      // __value < *__mid: the whole equal run lies to the left of __mid.
      __end = __mid;
      __len = __half_len;
    } else {
      // *__mid is equivalent to __value: resolve both edges of the run.
      _Iter __mp1 = __mid;
      return pair<_Iter, _Iter>(
          std::__lower_bound<_AlgPolicy>(__first, __mid, __value, __comp, __proj),
          std::__upper_bound<_AlgPolicy>(++__mp1, __end, __value, __comp, __proj));
    }
  }
  // No equivalent element: empty range positioned at the insertion point.
  return pair<_Iter, _Iter>(__first, __first);
}
The reason behind its simplicity is the fact that it uses std::lower_bound and std::upper_bound to locate range start and end.
Unfortunately, Rust doesn’t have lower/upper bound functions, but we can use partition_point to simulate both. Using this function, we can rewrite binary_search_slice as follows:
/// Finds, in the sorted slice `data`, the contiguous sub-slice of
/// elements whose key (as computed by `key_fn`) equals `key`. `data`
/// must have been sorted as if by `sort_by_key` with an equivalent key
/// function; when no element matches, the empty slice is returned.
pub fn binary_search_slice_new<'d, E, K>(data: &'d [E], key_fn: impl Fn(&E) -> K, key: &K) -> &'d [E]
where
    K: Ord,
{
    // Lower bound: index of the first element whose key is >= `key`,
    // or `data.len()` when every key is smaller.
    let lo = data.partition_point(|e| key_fn(e) < *key);
    match data.get(lo) {
        // `lo` lands on a genuine match, so the range is non-empty.
        Some(first) if key_fn(first) == *key => {
            // Upper bound: only the elements past `lo` (a known match)
            // need to be searched for the first key > `key`.
            let tail = lo + 1;
            let hi = tail + data[tail..].partition_point(|e| key_fn(e) <= *key);
            &data[lo..hi]
        }
        // Either `lo == data.len()` (all keys smaller) or the element
        // at `lo` has a strictly greater key -- nothing matches.
        _ => &[],
    }
}
I’ve added comments with invariants so that readers can easily convince themselves of the correctness of the proposed implementation, and it’s still significantly shorter than the original implementation. In fact, it’s possible to make it even shorter by skipping the if on line 10 and unconditionally returning &data[start..end], since if there are no elements with a matching key, start and end would be equal, making the start..end range empty. The reason I decided not to go with that implementation is that it would increase the number of comparisons in the worst case.
Since there is a popular myth that performant code must be more complex, some readers may wonder about the performance implications of such a change. Let’s find out using the following benchmark:
use criterion::{criterion_group, criterion_main, Criterion}; | |
/// Uses a sorted slice `data: &[E]` as a kind of "multi-map". The
/// `key_fn` extracts a key of type `K` from the data, and this
/// function finds the range of elements that match the key. `data`
/// must have been sorted as if by a call to `sort_by_key` for this to
/// work.
pub fn binary_search_slice<'d, E, K>(data: &'d [E], key_fn: impl Fn(&E) -> K, key: &K) -> &'d [E]
where
    K: Ord,
{
    // Locate *any* element with a matching key; if there is none, the
    // answer is the empty slice.
    let Ok(mid) = data.binary_search_by_key(key, &key_fn) else {
        return &[];
    };
    let size = data.len();
    // We get back *some* element with the given key -- so do
    // a galloping search backwards to find the *first* one.
    // `start` jumps left in exponentially growing steps until it exits
    // the run of equal keys (or hits 0); `previous` trails behind as
    // the last probed index known to still be inside the run.
    let mut start = mid;
    let mut previous = mid;
    let mut step = 1;
    loop {
        start = start.saturating_sub(step);
        if start == 0 || key_fn(&data[start]) != *key {
            break;
        }
        previous = start;
        step *= 2;
    }
    // Binary search within the bracket [start, previous] for the left
    // edge of the run of matching keys.
    step = previous - start;
    while step > 1 {
        let half = step / 2;
        let mid = start + half;
        if key_fn(&data[mid]) != *key {
            start = mid;
        }
        step -= half;
    }
    // adjust by one if we have overshot
    if start < size && key_fn(&data[start]) != *key {
        start += 1;
    }
    // Now search forward to find the *last* one.
    // Mirror image of the backward gallop: `end` jumps right past the
    // run (clamped to `size`); `previous` trails inside the run.
    let mut end = mid;
    let mut previous = mid;
    let mut step = 1;
    loop {
        end = end.saturating_add(step).min(size);
        if end == size || key_fn(&data[end]) != *key {
            break;
        }
        previous = end;
        step *= 2;
    }
    // Binary search within (previous, end] for the first index past the
    // run; `end` becomes the exclusive upper bound of the result.
    step = end - previous;
    while step > 1 {
        let half = step / 2;
        let mid = end - half;
        if key_fn(&data[mid]) != *key {
            end = mid;
        }
        step -= half;
    }
    &data[start..end]
}
/// Uses a sorted slice `data: &[E]` as a kind of "multi-map": finds the
/// contiguous sub-slice of elements whose key (as computed by `key_fn`)
/// equals `key`. `data` must have been sorted as if by `sort_by_key`
/// with an equivalent key function; when no element matches, the empty
/// slice is returned.
pub fn binary_search_slice_new<'d, E, K>(data: &'d [E], key_fn: impl Fn(&E) -> K, key: &K) -> &'d [E]
where
    K: Ord,
{
    // Lower bound: index of the first element whose key is >= `key`,
    // or `data.len()` when every key is smaller.
    let lo = data.partition_point(|e| key_fn(e) < *key);
    match data.get(lo) {
        // `lo` lands on a genuine match, so the range is non-empty.
        Some(first) if key_fn(first) == *key => {
            // Upper bound: only the elements past `lo` (a known match)
            // need to be searched for the first key > `key`.
            let tail = lo + 1;
            let hi = tail + data[tail..].partition_point(|e| key_fn(e) <= *key);
            &data[lo..hi]
        }
        // Either `lo == data.len()` (all keys smaller) or the element
        // at `lo` has a strictly greater key -- nothing matches.
        _ => &[],
    }
}
fn bench_inserts(c: &mut Criterion) { | |
let mut group = c.benchmark_group("multiply add"); | |
let data = [(0, "zero"), (3, "three-a"), (3, "three-b"), (22, "twenty-two")]; | |
let keys = [-1, 0, 1, 2, 3, 22, 23]; | |
group.bench_function("binary_search_slice", |b| { | |
b.iter(|| { | |
let mut total_len = 0; | |
for key in keys { | |
total_len += binary_search_slice(&data, |x| x.0, &key).len(); | |
} | |
total_len | |
}) | |
}); | |
group.bench_function("binary_search_slice_new", |b| { | |
b.iter(|| { | |
let mut total_len = 0; | |
for key in keys { | |
total_len += binary_search_slice_new(&data, |x| x.0, &key).len(); | |
} | |
total_len | |
}) | |
}); | |
group.finish(); | |
} | |
// Register the benchmark group and generate the binary's `main` entry point.
criterion_group!(benches, bench_inserts);
criterion_main!(benches);
And the results on my M1 MacBook Air are
suggesting a ~20% runtime improvement. Not bad for a code cleanup?
This is yet another reminder to follow Alexander Stepanov’s algorithm design principles and leverage algorithm composition.
P.S.: these suggestions are part of https://github.com/rust-lang/rust/pull/114231