Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions library/alloc/src/collections/binary_heap/extract_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use core::iter::FusedIterator;
use core::{fmt, ptr};

use super::BinaryHeap;
use crate::alloc::{Allocator, Global};

/// An iterator which uses a closure to determine if an element should be removed.
///
/// Elements for which the closure returns `true` are removed from the heap and yielded;
/// all other elements stay in the heap.
///
/// This struct is created by [`BinaryHeap::extract_if`].
/// See its documentation for more.
///
/// # Example
///
/// ```
/// #![feature(binary_heap_extract_if)]
/// use std::collections::BinaryHeap;
///
/// let mut heap: BinaryHeap<u32> = (0..128).collect();
/// let evens: Vec<u32> = heap.extract_if(|x| *x % 2 == 0).collect();
///
/// assert_eq!(evens.len(), 64);
/// assert_eq!(heap.len(), 64);
/// ```
#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
Comment thread
Nokel81 marked this conversation as resolved.
Outdated
// NOTE(review): the message below recommends `retain_mut`; confirm that `BinaryHeap`
// actually exposes a method of that name (it may need to say `retain`).
#[must_use = "iterators are lazy and do nothing unless consumed; \
use `retain_mut` or `extract_if().for_each(drop)` to remove and discard elements"]
pub struct ExtractIf<
'a,
T: Ord,
F,
#[unstable(feature = "allocator_api", issue = "32838")] A: Allocator = Global,
> {
// The heap being filtered. While the iterator is alive its backing vec has length 0.
heap: &'a mut BinaryHeap<T, A>,
// Number of elements the heap held when the iterator was created.
old_len: usize,
// Number of elements extracted (yielded) so far.
del: usize,
// Position of the next element in the backing buffer to hand to the predicate.
index: usize,
Comment thread
Nokel81 marked this conversation as resolved.
Outdated
// Closure deciding whether an element is extracted (`true`) or kept (`false`).
predicate: F,
}

impl<T: Ord, F, A: Allocator> ExtractIf<'_, T, F, A> {
/// Builds the iterator over `heap`, extracting the elements selected by `predicate`.
///
/// The backing buffer is sorted first, and the vec's length is set to 0 for the
/// lifetime of the iterator; the `Drop` impl restores the length and rebuilds the heap.
pub(super) fn new<'a>(heap: &'a mut BinaryHeap<T, A>, predicate: F) -> ExtractIf<'a, T, F, A> {
// Sorting breaks the heap invariant. That is fine because the length is set to 0
// below, and `Drop` restores the length and calls `rebuild()` before the heap is
// handed back to the caller.
Comment thread
Nokel81 marked this conversation as resolved.
Outdated
heap.sort_inner_vec();

let old_len = heap.len();
// SAFETY: leak amplification. A length of 0 never exceeds the capacity. The
// elements in `0..old_len` stay initialized in the buffer; if this iterator is
// leaked they are leaked with it instead of being observed in a partially moved state.
Comment thread
Nokel81 marked this conversation as resolved.
Outdated
unsafe { heap.data.set_len(0) };

ExtractIf { heap, predicate, index: 0, old_len, del: 0 }
}
}

#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
impl<T: Ord, F, A: Allocator> Iterator for ExtractIf<'_, T, F, A>
where
    F: FnMut(&T) -> bool,
{
    type Item = T;

    /// Walks the (sorted) backing buffer and returns the next element for which the
    /// predicate returns `true`.
    ///
    /// Elements the predicate rejects are shifted left over the slots vacated by earlier
    /// extractions, so the kept elements stay contiguous at the front of the buffer.
    fn next(&mut self) -> Option<Self::Item> {
        while self.index < self.old_len {
            let i = self.index;
            // SAFETY:
            // The loop guard gives `i < self.old_len`, and `old_len` was the length of the
            // backing vec before `new` set it to zero, so slot `i` is inside the allocation.
            //
            // Additionally, slot `i` still holds an initialized value: `index` only ever
            // grows, and no slot at or beyond `index` has been read out or overwritten.
            //
            // Note: we can't use `get_unchecked_mut(i)` here since the precondition for that
            // function is that `i < len`, but the vec's length is currently zero.
            let cur = unsafe { &mut *self.heap.data.as_mut_ptr().add(i) };
            let extract = (self.predicate)(cur);
            // Update the index *after* the predicate is called. If the index
            // is updated prior and the predicate panics, the element at this
            // index would be leaked.
            self.index += 1;
            if extract {
                self.del += 1;
                // SAFETY: We never touch this element again after returning it.
                return Some(unsafe { ptr::read(cur) });
            } else if self.del > 0 {
                // SAFETY: `self.del` > 0, so the hole slot must not overlap with current element.
                // We use copy for move, and never touch this element again.
                unsafe {
                    let hole_slot = self.heap.data.as_mut_ptr().add(i - self.del);
                    ptr::copy_nonoverlapping(cur, hole_slot, 1);
                }
            }
        }
        None
    }

    /// Reports that at most `old_len - index` elements can still be yielded.
    ///
    /// The lower bound is 0 because the predicate may reject every remaining element.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, Some(self.old_len - self.index))
    }
}

#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
impl<T: Ord, F, A: Allocator> Drop for ExtractIf<'_, T, F, A> {
// Restores the heap: closes the gap left by extracted elements, restores the vec's
// length, and rebuilds the heap invariant that `new` broke by sorting the buffer.
fn drop(&mut self) {
if self.del > 0 {
// SAFETY: The elements in `index..old_len` were never visited, so they are still
// initialized. They are shifted left by `del` slots to sit directly after the kept
// elements. Source and destination may overlap, hence `ptr::copy` rather than
// `ptr::copy_nonoverlapping`.
unsafe {
ptr::copy(
self.heap.data.as_ptr().add(self.index),
self.heap.data.as_mut_ptr().add(self.index - self.del),
self.old_len - self.index,
);
}
}
// SAFETY: After filling holes, the first `old_len - del` slots hold all remaining
// items in contiguous memory.
unsafe {
self.heap.data.set_len(self.old_len - self.del);
}
// The buffer is no longer in heap order (it was sorted in `new`), so re-establish it.
self.heap.rebuild();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking through this implementation, it seems to me that we end up replicating Vec::extract_if pretty closely. Maybe we can use it directly, possibly refactoring the impl to allow using it on a passed-in &mut Vec<T, A> for each call (essentially separating the state and the Vec being iterated)?

Alternatively, it's possible that the right pattern is to tell users that the right pattern for extract_if is:

let vec = BinaryHeap::into_sorted_vec(mem::take(heap)); // take only needed if you have only &mut.
vec.extract_if(...);
*heap = BinaryHeap::from(vec);

AFAICT, that is equally efficient to the implementation here, and we can't do better since we give &mut T access to the elements: we have to assume all elements were changed by the pass over them, so we can't rebuild more efficiently than re-sifting the whole heap. (Even though in the common case removals happened but probably no changes to the elements occurred).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had considered trying to directly reuse the Vec::extract_if impl directly, but I was unable to figure out how to allocate an empty Vec<T> in an arbitrary allocator given only a reference to the allocator

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have found a way to reuse the Vec::extract_if implementation, PTAL

}
}

// Once `index` reaches `old_len`, `next` keeps returning `None`, so the iterator is fused.
#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
impl<T, F, A> FusedIterator for ExtractIf<'_, T, F, A>
where
    T: Ord,
    F: FnMut(&T) -> bool,
    A: Allocator,
{
}

#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
impl<T, F, A> fmt::Debug for ExtractIf<'_, T, F, A>
where
    T: Ord + fmt::Debug,
    A: Allocator,
{
    /// Formats the iterator as an opaque `ExtractIf { .. }`; no fields are printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut opaque = f.debug_struct("ExtractIf");
        opaque.finish_non_exhaustive()
    }
}
25 changes: 24 additions & 1 deletion library/alloc/src/collections/binary_heap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ use core::num::NonZero;
use core::ops::{Deref, DerefMut};
use core::{fmt, ptr};

#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
pub use self::extract_if::ExtractIf;

mod extract_if;

use crate::alloc::Global;
use crate::collections::TryReserveError;
use crate::slice;
Expand Down Expand Up @@ -772,6 +777,12 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
#[must_use = "`self` will be dropped if the result is not used"]
#[stable(feature = "binary_heap_extras_15", since = "1.5.0")]
pub fn into_sorted_vec(mut self) -> Vec<T, A> {
// The in-place sort lives in `sort_inner_vec` so that `ExtractIf` can reuse it
// without consuming the heap.
self.sort_inner_vec();
self.into_vec()
}

/// Sorts the inner data in place in ascending order, which leaves the heap invariant
/// broken until the heap is rebuilt. Used to implement `into_sorted_vec` and `ExtractIf`.
fn sort_inner_vec(&mut self) {
let mut end = self.len();
while end > 1 {
end -= 1;
Expand All @@ -788,7 +799,6 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
// Which means 0 < end and end < self.len().
unsafe { self.sift_down_range(0, end) };
}
self.into_vec()
}

// The implementations of sift_up and sift_down use unsafe blocks in
Expand Down Expand Up @@ -1039,6 +1049,19 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
DrainSorted { inner: self }
}

/// Creates an iterator which uses a closure to determine if an element should be removed.
///
/// Elements are visited in ascending sorted order. If the closure returns `true`, the
/// element is removed from the heap and yielded; if it returns `false`, the element
/// stays in the heap.
///
/// If the returned iterator is dropped before it is fully consumed, the elements that
/// were not yet visited are retained in the heap.
///
/// # Examples
///
/// ```
/// #![feature(binary_heap_extract_if)]
/// use std::collections::BinaryHeap;
///
/// let mut heap = BinaryHeap::from([10, 15, 11]);
/// let evens: Vec<_> = heap.extract_if(|x| *x % 2 == 0).collect();
///
/// assert_eq!(evens, [10]);
/// assert_eq!(heap.into_sorted_vec(), [11, 15]);
/// ```
#[unstable(feature = "binary_heap_extract_if", issue = "42849")]
Comment thread
Nokel81 marked this conversation as resolved.
Outdated
#[must_use]
pub fn extract_if<F>(&mut self, predicate: F) -> ExtractIf<'_, T, F, A>
Comment thread
Nokel81 marked this conversation as resolved.
where
F: FnMut(&T) -> bool,
{
ExtractIf::new(self, predicate)
}

/// Retains only the elements specified by the predicate.
///
/// In other words, remove all elements `e` for which `f(&e)` returns
Expand Down
62 changes: 62 additions & 0 deletions library/alloctests/tests/collections/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,65 @@ fn panic_safe() {
}
}
}

#[test]
fn given_a_binary_heap_can_create_an_extract_if_iterator() {
    // An empty heap never invokes the predicate, so the predicate is free to panic.
    let mut heap: BinaryHeap<usize> = BinaryHeap::new();

    heap.extract_if(|_| unreachable!("there's nothing to decide on")).for_each(drop);

    assert!(heap.is_empty());
}

#[test]
fn given_some_binary_heap_with_one_item_when_extracting_if_true_extracts_all_items() {
    // A single-element heap with an always-true predicate must end up empty.
    let mut heap: BinaryHeap<usize> = BinaryHeap::new();
    heap.push(10);

    let extracted = heap.extract_if(|_| true).collect::<Vec<_>>();

    assert!(heap.is_empty());
    assert_eq!(extracted, vec![10]);
}

#[test]
fn given_some_binary_heap_with_three_items_when_extracting_if_true_extracts_all_items_in_sorted_order()
{
    // Items are pushed out of order; extraction must yield them in ascending order.
    let mut heap = BinaryHeap::new();
    for value in [10, 15, 11] {
        heap.push(value);
    }

    let extracted = heap.extract_if(|_| true).collect::<Vec<_>>();

    assert!(heap.is_empty());
    assert_eq!(extracted, vec![10, 11, 15]);
}

#[test]
fn given_some_binary_heap_with_some_items_when_extracting_if_even_extracts_just_even_items() {
    let mut heap = BinaryHeap::new();
    for value in [10, 15, 11] {
        heap.push(value);
    }

    let evens = heap.extract_if(|value| *value % 2 == 0).collect::<Vec<_>>();
    assert_eq!(evens, vec![10]);

    // The odd items must remain and still pop in max-heap order.
    for expected in [Some(15), Some(11), None] {
        assert_eq!(heap.pop(), expected);
    }
}

#[test]
fn given_some_binary_heap_with_some_items_when_extracting_if_when_dropping_without_iterating_leaves_heap_in_valid_state()
{
    let mut heap = BinaryHeap::new();
    for value in [10, 15, 11] {
        heap.push(value);
    }

    // Create the iterator and let it go out of scope without ever calling `next`.
    {
        let _unconsumed = heap.extract_if(|value| *value % 2 == 0);
    }

    // Nothing was extracted, and the heap must pop in max-heap order again.
    for expected in [Some(15), Some(11), Some(10), None] {
        assert_eq!(heap.pop(), expected);
    }
}
1 change: 1 addition & 0 deletions library/alloctests/tests/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![feature(allocator_api)]
#![feature(binary_heap_extract_if)]
#![feature(binary_heap_pop_if)]
#![feature(const_heap)]
#![feature(deque_extend_front)]
Expand Down
Loading