use super::plumbing::*;
use super::*;
use rayon_core::join;
use std::iter;
/// `Chain` is an iterator that joins `b` after `a` in one continuous iterator.
/// This struct is created by the [`chain()`] method on [`ParallelIterator`].
///
/// [`chain()`]: ParallelIterator::chain()
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Debug, Clone)]
pub struct Chain<A, B> {
    /// The iterator whose items come first.
    a: A,
    /// The iterator whose items follow all of `a`'s items.
    b: B,
}
impl<A, B> Chain<A, B> {
    /// Creates a new `Chain` iterator that yields everything from `a`, then
    /// everything from `b`.
    pub(super) fn new(a: A, b: B) -> Self {
        Chain { a, b }
    }
}
impl ParallelIterator for Chain
where
A: ParallelIterator,
B: ParallelIterator- ,
{
type Item = A::Item;
fn drive_unindexed(self, consumer: C) -> C::Result
where
C: UnindexedConsumer,
{
let Chain { a, b } = self;
// If we returned a value from our own `opt_len`, then the collect consumer in particular
// will balk at being treated like an actual `UnindexedConsumer`. But when we do know the
// length, we can use `Consumer::split_at` instead, and this is still harmless for other
// truly-unindexed consumers too.
let (left, right, reducer) = if let Some(len) = a.opt_len() {
consumer.split_at(len)
} else {
let reducer = consumer.to_reducer();
(consumer.split_off_left(), consumer, reducer)
};
let (a, b) = join(|| a.drive_unindexed(left), || b.drive_unindexed(right));
reducer.reduce(a, b)
}
fn opt_len(&self) -> Option {
self.a.opt_len()?.checked_add(self.b.opt_len()?)
}
}
impl IndexedParallelIterator for Chain
where
A: IndexedParallelIterator,
B: IndexedParallelIterator
- ,
{
fn drive(self, consumer: C) -> C::Result
where
C: Consumer,
{
let Chain { a, b } = self;
let (left, right, reducer) = consumer.split_at(a.len());
let (a, b) = join(|| a.drive(left), || b.drive(right));
reducer.reduce(a, b)
}
fn len(&self) -> usize {
self.a.len().checked_add(self.b.len()).expect("overflow")
}
fn with_producer(self, callback: CB) -> CB::Output
where
CB: ProducerCallback,
{
let a_len = self.a.len();
return self.a.with_producer(CallbackA {
callback,
a_len,
b: self.b,
});
struct CallbackA {
callback: CB,
a_len: usize,
b: B,
}
impl ProducerCallback for CallbackA
where
B: IndexedParallelIterator,
CB: ProducerCallback,
{
type Output = CB::Output;
fn callback(self, a_producer: A) -> Self::Output
where
A: Producer
- ,
{
self.b.with_producer(CallbackB {
callback: self.callback,
a_len: self.a_len,
a_producer,
})
}
}
struct CallbackB {
callback: CB,
a_len: usize,
a_producer: A,
}
impl ProducerCallback for CallbackB
where
A: Producer,
CB: ProducerCallback,
{
type Output = CB::Output;
fn callback(self, b_producer: B) -> Self::Output
where
B: Producer
- ,
{
let producer = ChainProducer::new(self.a_len, self.a_producer, b_producer);
self.callback.callback(producer)
}
}
}
}
// ////////////////////////////////////////////////////////////////////////
struct ChainProducer
where
A: Producer,
B: Producer
- ,
{
a_len: usize,
a: A,
b: B,
}
impl ChainProducer
where
A: Producer,
B: Producer
- ,
{
fn new(a_len: usize, a: A, b: B) -> Self {
ChainProducer { a_len, a, b }
}
}
impl Producer for ChainProducer
where
A: Producer,
B: Producer
- ,
{
type Item = A::Item;
type IntoIter = ChainSeq;
fn into_iter(self) -> Self::IntoIter {
ChainSeq::new(self.a.into_iter(), self.b.into_iter())
}
fn min_len(&self) -> usize {
Ord::max(self.a.min_len(), self.b.min_len())
}
fn max_len(&self) -> usize {
Ord::min(self.a.max_len(), self.b.max_len())
}
fn split_at(self, index: usize) -> (Self, Self) {
if index <= self.a_len {
let a_rem = self.a_len - index;
let (a_left, a_right) = self.a.split_at(index);
let (b_left, b_right) = self.b.split_at(0);
(
ChainProducer::new(index, a_left, b_left),
ChainProducer::new(a_rem, a_right, b_right),
)
} else {
let (a_left, a_right) = self.a.split_at(self.a_len);
let (b_left, b_right) = self.b.split_at(index - self.a_len);
(
ChainProducer::new(self.a_len, a_left, b_left),
ChainProducer::new(0, a_right, b_right),
)
}
}
fn fold_with(self, mut folder: F) -> F
where
F: Folder,
{
folder = self.a.fold_with(folder);
if folder.full() {
folder
} else {
self.b.fold_with(folder)
}
}
}
// ////////////////////////////////////////////////////////////////////////
/// Sequential iterator for `ChainProducer`: a thin wrapper around
/// [`std::iter::Chain`] that additionally implements [`ExactSizeIterator`].
///
/// `std::iter::Chain` does not implement `ExactSizeIterator` itself. Here the
/// default `len()` is derived from the forwarded `size_hint()`, which is exact as
/// long as both halves are exact-size and their combined length fits in `usize`.
struct ChainSeq<A, B> {
    chain: iter::Chain<A, B>,
}

impl<A, B> ChainSeq<A, B> {
    /// Chains `b` after `a`; both halves must be exact-size iterators of the same item type.
    fn new(a: A, b: B) -> ChainSeq<A, B>
    where
        A: ExactSizeIterator,
        B: ExactSizeIterator<Item = A::Item>,
    {
        ChainSeq { chain: a.chain(b) }
    }
}

impl<A, B> Iterator for ChainSeq<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    fn next(&mut self) -> Option<Self::Item> {
        self.chain.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.chain.size_hint()
    }

    /// Forwarded so that internal iteration (`fold`, `for_each`, ...) uses
    /// `iter::Chain`'s own `fold`, which runs each half to completion in turn,
    /// instead of the default item-by-item loop over `next`.
    fn fold<Acc, G>(self, init: Acc, g: G) -> Acc
    where
        G: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.fold(init, g)
    }
}

impl<A, B> ExactSizeIterator for ChainSeq<A, B>
where
    A: ExactSizeIterator,
    B: ExactSizeIterator<Item = A::Item>,
{
}

impl<A, B> DoubleEndedIterator for ChainSeq<A, B>
where
    A: DoubleEndedIterator,
    B: DoubleEndedIterator<Item = A::Item>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.chain.next_back()
    }

    /// Forwarded for the same reason as `fold`, in reverse order.
    fn rfold<Acc, G>(self, init: Acc, g: G) -> Acc
    where
        G: FnMut(Acc, Self::Item) -> Acc,
    {
        self.chain.rfold(init, g)
    }
}