// joinery/iter.rs
//! Joinery iterator and related types and traits
2
3use core::{
4    fmt::{self, Debug, Display, Formatter},
5    iter::FusedIterator,
6    mem,
7};
8
9#[cfg(feature = "nightly")]
10use core::{iter::TrustedLen, ops::Try};
11
12use crate::{
13    join::{Join, Joinable},
14    separators::NoSeparator,
15};
16
/// Adapter that lets any cloneable [`Iterator`] be used with [`Join`].
///
/// [`Join`] only works with types where `&T: IntoIterator`, because formatting
/// it via [`Display`] has to walk the sequence through a shared reference.
/// Collections such as [`Vec`](https://doc.rust-lang.org/std/vec/struct.Vec.html)
/// satisfy that, but a plain iterator does not. Many iterators are cheap to
/// clone, though (they frequently hold little more than a reference into some
/// collection), so this wrapper provides `&CloneIterator<I>: IntoIterator` by
/// handing out a fresh clone of the wrapped iterator each time, which makes it
/// displayable by `Join`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[repr(transparent)]
pub struct CloneIterator<I>(I);

impl<I> IntoIterator for CloneIterator<I>
where
    I: Iterator,
{
    type Item = I::Item;
    type IntoIter = I;

    /// Unwrap the adapter, giving back the iterator it was built from.
    fn into_iter(self) -> Self::IntoIter {
        let CloneIterator(inner) = self;
        inner
    }
}

impl<I> IntoIterator for &CloneIterator<I>
where
    I: Iterator + Clone,
{
    type Item = I::Item;
    type IntoIter = I;

    /// Produce an independent copy of the wrapped iterator. The items yielded
    /// are the underlying iterator's own `Item`s, not references to them.
    fn into_iter(self) -> Self::IntoIter {
        Clone::clone(&self.0)
    }
}
51
52/// A trait for converting [`Iterator`]s into [`Join`] instances or [`JoinIter`]
53/// iterators.
54///
55/// This trait serves the same purpose as [`Joinable`], but is implemented for `Iterator`
56/// types. The main difference between [`JoinableIterator`] and [`Joinable`] is that,
57/// because iterators generally don't implement `&T: IntoIterator`, we need a different
58/// mechanism to allow for immutably iterating (which is required for [`Join`]'s implementation
59/// of [`Display`]).
60pub trait JoinableIterator: Iterator + Sized {
61    /// Convert a cloneable iterator into a [`Join`] instance. Whenever the [`Join`]
62    /// needs to immutabley iterate over the underlying iterator (for instance, when
63    /// formatting it with [`Display`]), the underlying iterator is cloned. For most
64    /// iterator types this is a cheap operation, because the iterator contains just
65    /// a reference to the underlying collection.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// use joinery::JoinableIterator;
71    ///
72    /// let result = (0..4).map(|x| x * 2).join_with(", ").to_string();
73    ///
74    /// assert_eq!(result, "0, 2, 4, 6");
75    /// ```
76    fn join_with<S>(self, sep: S) -> Join<CloneIterator<Self>, S>
77    where
78        Self: Clone,
79    {
80        CloneIterator(self).join_with(sep)
81    }
82
83    /// Convert a [cloneable][Clone] iterator into a [`Join`] instance with no separator.
84    /// When formatted with [`Display`], the elements of the iterator will be directly
85    /// concatenated.
86    /// # Examples
87    ///
88    /// ```
89    /// use joinery::JoinableIterator;
90    ///
91    /// let result = (0..4).map(|x| x * 2).join_concat().to_string();
92    ///
93    /// assert_eq!(result, "0246");
94    /// ```
95    fn join_concat(self) -> Join<CloneIterator<Self>, NoSeparator>
96    where
97        Self: Clone,
98    {
99        self.join_with(NoSeparator)
100    }
101
102    /// Create an iterator which interspeses the elements of this iterator with
103    /// a separator. See [`JoinIter`] for more details.
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use joinery::{JoinableIterator, JoinItem};
109    ///
110    /// let mut iter = (0..3).map(|x| x * 2).iter_join_with(", ");
111    ///
112    /// assert_eq!(iter.next(), Some(JoinItem::Element(0)));
113    /// assert_eq!(iter.next(), Some(JoinItem::Separator(", ")));
114    /// assert_eq!(iter.next(), Some(JoinItem::Element(2)));
115    /// assert_eq!(iter.next(), Some(JoinItem::Separator(", ")));
116    /// assert_eq!(iter.next(), Some(JoinItem::Element(4)));
117    /// assert_eq!(iter.next(), None);
118    /// ```
119    fn iter_join_with<S>(self, sep: S) -> JoinIter<Self, S> {
120        JoinIter::new(self, sep)
121    }
122}
123
124impl<T: Iterator> JoinableIterator for T {}
125
/// Enum representing the elements of a [`JoinIter`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinItem<T, S> {
    /// An element from the underlying iterator
    Element(T),
    /// A separator between two elements
    Separator(S),
}

impl<T, S> JoinItem<T, S> {
    /// Convert a [`JoinItem`] into a common type `R`, in the case where both
    /// `T` and `S` can be converted to `R`. Unfortunately, due to potentially
    /// conflicting implementations, we can't implement [`Into<R>`][Into] for
    /// [`JoinItem`].
    pub fn into<R>(self) -> R
    where
        T: Into<R>,
        S: Into<R>,
    {
        match self {
            JoinItem::Element(el) => el.into(),
            JoinItem::Separator(sep) => sep.into(),
        }
    }
}

// `R: ?Sized` mirrors the `?Sized` parameter of `AsRef` itself, so the common
// target can be an unsized type such as `str`, `[u8]` or `Path`. Without it
// the impl would only exist for `Sized` targets.
impl<R: ?Sized, T: AsRef<R>, S: AsRef<R>> AsRef<R> for JoinItem<T, S> {
    /// Get a reference to a common type `R` from a [`JoinItem`], in the case where
    /// both `T` and `S` implement [`AsRef<R>`][AsRef]
    fn as_ref(&self) -> &R {
        match self {
            JoinItem::Element(el) => el.as_ref(),
            JoinItem::Separator(sep) => sep.as_ref(),
        }
    }
}

// As with `AsRef` above, `R` may be unsized (e.g. `str` or `[u8]`).
impl<R: ?Sized, T: AsMut<R>, S: AsMut<R>> AsMut<R> for JoinItem<T, S> {
    /// Get a mutable reference to a common type `R` from a [`JoinItem`], in the
    /// case where both `T` and `S` implement [`AsMut<R>`][AsMut]
    fn as_mut(&mut self) -> &mut R {
        match self {
            JoinItem::Element(el) => el.as_mut(),
            JoinItem::Separator(sep) => sep.as_mut(),
        }
    }
}

/// Formats whichever value the item holds, forwarding the formatter (and
/// therefore any width/fill/precision flags) unchanged.
impl<T: Display, S: Display> Display for JoinItem<T, S> {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            JoinItem::Element(el) => Display::fmt(el, f),
            JoinItem::Separator(sep) => Display::fmt(sep, f),
        }
    }
}
182
/// Internal bookkeeping for a `JoinIter`: records what the next call to
/// `next` should try to produce.
#[derive(Clone, Copy, Debug)]
enum JoinIterState<T> {
    /// Nothing has been yielded yet: the first element of the inner iterator
    /// is returned as-is, with no separator in front of it.
    Initial,

    /// A separator comes next, but only if the inner iterator still has an
    /// element to put after it.
    Separator,

    /// An element was already pulled from the inner iterator and a separator
    /// was emitted ahead of it; this buffered element is yielded next.
    Element(T),
}
195
196/// An iterator for a [`Join`].
197///
198/// Emits the elements of the [`Join`]'s underlying iterator, interspersed with
199/// its separator. Note that it uses [`clone`][Clone::clone] to generate copies
200/// of the separator while iterating, but also keep in mind that in most cases
201/// the [`JoinItem`] instance will have a trivially cloneable separator, such
202/// as [`&`](https://doc.rust-lang.org/std/primitive.reference.html)[`str`][str]
203/// or [`char`].
204///
205/// # Examples
206///
207/// Via [`IntoIterator`]:
208///
209/// ```
210/// use joinery::{Joinable, JoinItem};
211///
212/// let join = vec![1, 2, 3].join_with(" ");
213/// let mut join_iter = join.into_iter();
214///
215/// assert_eq!(join_iter.next(), Some(JoinItem::Element(1)));
216/// assert_eq!(join_iter.next(), Some(JoinItem::Separator(" ")));
217/// assert_eq!(join_iter.next(), Some(JoinItem::Element(2)));
218/// assert_eq!(join_iter.next(), Some(JoinItem::Separator(" ")));
219/// assert_eq!(join_iter.next(), Some(JoinItem::Element(3)));
220/// assert_eq!(join_iter.next(), None);
221/// ```
222///
223/// Via [`iter_join_with`][JoinableIterator::iter_join_with]:
224///
225/// ```
226/// use joinery::{JoinableIterator, JoinItem};
227///
228/// let mut iter = (0..6)
229///     .filter(|x| x % 2 == 0)
230///     .map(|x| x * 2)
231///     .iter_join_with(", ");
232///
233/// assert_eq!(iter.next(), Some(JoinItem::Element(0)));
234/// assert_eq!(iter.next(), Some(JoinItem::Separator(", ")));
235/// assert_eq!(iter.next(), Some(JoinItem::Element(4)));
236/// assert_eq!(iter.next(), Some(JoinItem::Separator(", ")));
237/// assert_eq!(iter.next(), Some(JoinItem::Element(8)));
238/// assert_eq!(iter.next(), None);
239/// ```
240#[must_use]
241pub struct JoinIter<Iter: Iterator, Sep> {
242    iter: Iter,
243    sep: Sep,
244    state: JoinIterState<Iter::Item>,
245}
246
247impl<I: Iterator, S> JoinIter<I, S> {
248    /// Construct a new [`JoinIter`] using an iterator and a separator
249    fn new(iter: I, sep: S) -> Self {
250        JoinIter {
251            iter,
252            sep,
253            state: JoinIterState::Initial,
254        }
255    }
256}
257
258impl<I: Iterator, S> JoinIter<I, S> {
259    /// Check if the next iteration of this iterator will (try to) return a
260    /// separator. Note that this does not check if the underlying iterator is
261    /// empty, so the next `next` call could still return `None`.
262    ///
263    /// # Examples
264    ///
265    /// ```
266    /// use joinery::{JoinableIterator, JoinItem};
267    ///
268    /// let mut join_iter = (0..3).join_with(", ").into_iter();
269    ///
270    /// assert_eq!(join_iter.is_sep_next(), false);
271    /// join_iter.next();
272    /// assert_eq!(join_iter.is_sep_next(), true);
273    /// join_iter.next();
274    /// assert_eq!(join_iter.is_sep_next(), false);
275    /// ```
276    #[inline]
277    pub fn is_sep_next(&self) -> bool {
278        matches!(self.state, JoinIterState::Separator)
279    }
280
281    /// Get a reference to the separator.
282    #[inline]
283    pub fn sep(&self) -> &S {
284        &self.sep
285    }
286}
287
288impl<I: Debug + Iterator, S: Debug> Debug for JoinIter<I, S>
289where
290    I::Item: Debug,
291{
292    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
293        f.debug_struct("JoinIter")
294            .field("iter", &self.iter)
295            .field("sep", &self.sep)
296            .field("state", &self.state)
297            .finish()
298    }
299}
300
301impl<I: Clone + Iterator, S: Clone> Clone for JoinIter<I, S>
302where
303    I::Item: Clone, // Needed because we use a peekable iterator
304{
305    fn clone(&self) -> Self {
306        JoinIter {
307            iter: self.iter.clone(),
308            sep: self.sep.clone(),
309            state: self.state.clone(),
310        }
311    }
312
313    fn clone_from(&mut self, source: &Self) {
314        self.iter.clone_from(&source.iter);
315        self.sep.clone_from(&source.sep);
316        self.state.clone_from(&source.state);
317    }
318}
319
320/// Get the size of a [`JoinIter`], given the size of the underlying iterator. If
321/// next_sep is true, the next element in the [`JoinIter`] will be the separator.
322/// Return None in the event of an overflow. This logic is provided as a separate
323/// function in the hopes that it will aid compiler optimization, and also with
324/// the intention that in the future it will be a `const fn`.
325#[inline]
326fn join_size<T>(iter_size: usize, state: &JoinIterState<T>) -> Option<usize> {
327    match *state {
328        JoinIterState::Initial => match iter_size {
329            0 => Some(0),
330            _ => (iter_size - 1).checked_mul(2)?.checked_add(1),
331        },
332        JoinIterState::Separator => iter_size.checked_mul(2),
333        JoinIterState::Element(..) => iter_size.checked_mul(2)?.checked_add(1),
334    }
335}
336
337impl<I: Iterator, S: Clone> Iterator for JoinIter<I, S> {
338    type Item = JoinItem<I::Item, S>;
339
340    /// Advance to the next item in the Join. This will either be the next
341    /// element in the underlying iterator, or a clone of the separator.
342    // We tag it inline in the hopes that the compiler can optimize loops into
343    // (mostly) branchless versions, similar to `(Join as Display)::fmt`
344    #[inline]
345    fn next(&mut self) -> Option<Self::Item> {
346        match mem::replace(&mut self.state, JoinIterState::Separator) {
347            JoinIterState::Initial => self.iter.next().map(JoinItem::Element),
348            JoinIterState::Separator => self.iter.next().map(|element| {
349                self.state = JoinIterState::Element(element);
350                JoinItem::Separator(self.sep.clone())
351            }),
352            JoinIterState::Element(element) => Some(JoinItem::Element(element)),
353        }
354    }
355
356    fn size_hint(&self) -> (usize, Option<usize>) {
357        let (min, max) = self.iter.size_hint();
358
359        let min = join_size(min, &self.state).unwrap_or(core::usize::MAX);
360        let max = max.and_then(|max| join_size(max, &self.state));
361
362        (min, max)
363    }
364
365    fn count(self) -> usize
366    where
367        Self: Sized,
368    {
369        match self.state {
370            JoinIterState::Initial => (self.iter.count() * 2).saturating_sub(1),
371            JoinIterState::Separator => self.iter.count() * 2,
372            JoinIterState::Element(_) => self.iter.count() * 2 + 1,
373        }
374    }
375
376    fn last(self) -> Option<Self::Item>
377    where
378        Self: Sized,
379    {
380        self.iter
381            .last()
382            .or(match self.state {
383                JoinIterState::Initial | JoinIterState::Separator => None,
384                JoinIterState::Element(item) => Some(item),
385            })
386            .map(JoinItem::Element)
387    }
388
389    fn fold<B, F>(mut self, init: B, mut func: F) -> B
390    where
391        F: FnMut(B, Self::Item) -> B,
392    {
393        let accum = match self.state {
394            JoinIterState::Initial => match self.iter.next() {
395                None => return init,
396                Some(element) => func(init, JoinItem::Element(element)),
397            },
398            JoinIterState::Separator => init,
399            JoinIterState::Element(element) => func(init, JoinItem::Element(element)),
400        };
401
402        self.iter.fold(accum, move |accum, element| {
403            let accum = func(accum, JoinItem::Separator(self.sep.clone()));
404            func(accum, JoinItem::Element(element))
405        })
406    }
407
408    #[cfg(feature = "nightly")]
409    fn try_fold<B, F, R>(&mut self, init: B, mut func: F) -> R
410    where
411        Self: Sized,
412        F: FnMut(B, Self::Item) -> R,
413        R: Try<Output = B>,
414    {
415        use core::ops::ControlFlow;
416
417        let accum = match mem::replace(&mut self.state, JoinIterState::Separator) {
418            JoinIterState::Initial => match self.iter.next() {
419                None => return R::from_output(init),
420                Some(element) => func(init, JoinItem::Element(element))?,
421            },
422            JoinIterState::Separator => init,
423            JoinIterState::Element(element) => func(init, JoinItem::Element(element))?,
424        };
425
426        self.iter.try_fold(accum, |accum, element| {
427            match func(accum, JoinItem::Separator(self.sep.clone())).branch() {
428                ControlFlow::Break(err) => {
429                    self.state = JoinIterState::Element(element);
430                    R::from_residual(err)
431                }
432                ControlFlow::Continue(accum) => func(accum, JoinItem::Element(element)),
433            }
434        })
435    }
436}
437
438impl<I: FusedIterator, S: Clone> FusedIterator for JoinIter<I, S> {}
439
// SAFETY: `size_hint` maps the inner iterator's hint through `join_size`. When
// the `TrustedLen` inner iterator reports an exact length, `join_size` gives
// the exact number of items `next` will produce from the current `state`.
// Whenever that count would overflow `usize`, including when the inner
// iterator itself reports `(usize::MAX, None)`, `size_hint` returns
// `(usize::MAX, None)`. This is exactly what the `TrustedLen` contract requires.
#[cfg(feature = "nightly")]
unsafe impl<I: TrustedLen, S: Clone> TrustedLen for JoinIter<I, S> {}
442
#[cfg(test)]
mod tests {
    // Unit tests for `JoinIter`: plain iteration, size hints, and the
    // overridden `count`, `last`, `fold` and `try_fold` in each state.
    use super::JoinItem::*;
    use super::JoinableIterator;

    #[test]
    fn empty_iter() {
        let mut join_iter = (0..0).iter_join_with(", ");

        assert_eq!(join_iter.next(), None);
        assert_eq!(join_iter.next(), None);
    }

    #[test]
    fn single() {
        let mut join_iter = (0..1).iter_join_with(", ");

        assert_eq!(join_iter.next(), Some(Element(0)));
        assert_eq!(join_iter.next(), None);
        assert_eq!(join_iter.next(), None);
    }

    #[test]
    fn few() {
        let mut join_iter = (0..3).iter_join_with(", ");

        assert_eq!(join_iter.next(), Some(Element(0)));
        assert_eq!(join_iter.next(), Some(Separator(", ")));
        assert_eq!(join_iter.next(), Some(Element(1)));
        assert_eq!(join_iter.next(), Some(Separator(", ")));
        assert_eq!(join_iter.next(), Some(Element(2)));
        assert_eq!(join_iter.next(), None);
        assert_eq!(join_iter.next(), None);
    }

    #[test]
    fn is_sep_next_tracks_state() {
        let mut join_iter = (0..2).iter_join_with(',');

        assert!(!join_iter.is_sep_next());
        assert_eq!(join_iter.next(), Some(Element(0)));
        assert!(join_iter.is_sep_next());
        assert_eq!(join_iter.next(), Some(Separator(',')));
        assert!(!join_iter.is_sep_next());
        assert_eq!(join_iter.next(), Some(Element(1)));
        assert!(join_iter.is_sep_next());
        assert_eq!(join_iter.next(), None);
        assert_eq!(*join_iter.sep(), ',');
    }

    #[test]
    fn regular_size_hint() {
        let mut join_iter = (0..10).iter_join_with(", ");

        for size in (0..=19).rev() {
            assert_eq!(join_iter.size_hint(), (size, Some(size)));
            join_iter.next();
        }

        assert_eq!(join_iter.size_hint(), (0, Some(0)));
        join_iter.next();
        assert_eq!(join_iter.size_hint(), (0, Some(0)));
    }

    #[test]
    fn inexact_size_hint() {
        // `Filter` reports (0, Some(6)), so the join is between 0 and 11 items.
        let join_iter = (0..6).filter(|x| x % 2 == 0).iter_join_with(", ");
        assert_eq!(join_iter.size_hint(), (0, Some(11)));
    }

    #[test]
    fn large_size_hint() {
        let join_iter = (0..usize::max_value() - 10).iter_join_with(", ");
        assert_eq!(join_iter.size_hint(), (usize::max_value(), None));
    }

    #[test]
    fn threshold_size_hint() {
        use core::usize::MAX as usize_max;
        let usize_threshold = (usize_max / 2) + 1;

        let mut join_iter = (0..usize_threshold + 1).iter_join_with(", ");
        assert_eq!(join_iter.size_hint(), (usize_max, None));

        join_iter.next();
        assert_eq!(join_iter.size_hint(), (usize_max, None));

        join_iter.next();
        assert_eq!(join_iter.size_hint(), (usize_max, Some(usize_max)));

        join_iter.next();
        assert_eq!(join_iter.size_hint(), (usize_max - 1, Some(usize_max - 1)));
    }

    #[test]
    fn partial_iteration() {
        use std::vec::Vec;

        let mut join_iter = (0..3).iter_join_with(' ');

        join_iter.next();

        let rest: Vec<_> = join_iter.collect();
        assert_eq!(
            rest,
            [Separator(' '), Element(1), Separator(' '), Element(2),]
        );
    }

    #[test]
    fn clone_mid_iteration() {
        use std::vec::Vec;

        let mut join_iter = (0..3).iter_join_with(", ");
        join_iter.next();
        join_iter.next(); // Element(1) is now buffered in the state

        let copy = join_iter.clone();

        let original: Vec<_> = join_iter.collect();
        let cloned: Vec<_> = copy.collect();

        assert_eq!(original, [Element(1), Separator(", "), Element(2)]);
        assert_eq!(original, cloned);
    }

    #[test]
    fn count_in_each_state() {
        let mut join_iter = (0..3).iter_join_with(", ");

        // Initial: E, S, E, S, E
        assert_eq!(join_iter.clone().count(), 5);

        join_iter.next();
        // Separator: S, E, S, E
        assert_eq!(join_iter.clone().count(), 4);

        join_iter.next();
        // Element(1) buffered: E, S, E
        assert_eq!(join_iter.clone().count(), 3);
    }

    #[test]
    fn count_empty() {
        assert_eq!((0..0).iter_join_with(", ").count(), 0);
    }

    #[test]
    fn fold() {
        let content = [1, 2, 3];
        let join_iter = content.iter().iter_join_with(4);

        let sum = join_iter.fold(0, |accum, next| match next {
            Element(el) => accum + el,
            Separator(sep) => accum + sep,
        });

        assert_eq!(sum, 14);
    }

    #[test]
    fn partial_fold() {
        let content = [1, 2, 3, 4];
        let mut join_iter = content.iter().iter_join_with(1);

        join_iter.next();
        join_iter.next();
        join_iter.next();

        let sum = join_iter.fold(0, |accum, next| match next {
            Element(el) => accum + el,
            Separator(sep) => accum + sep,
        });

        assert_eq!(sum, 9);
    }

    #[test]
    fn fold_from_buffered_element() {
        let content = [1, 2, 3];
        let mut join_iter = content.iter().iter_join_with(10);

        join_iter.next(); // Element(1)
        join_iter.next(); // Separator(10); Element(2) is now buffered

        let sum = join_iter.fold(0, |accum, next| match next {
            Element(el) => accum + el,
            Separator(sep) => accum + sep,
        });

        assert_eq!(sum, 2 + 10 + 3);
    }

    #[test]
    fn fold_empty() {
        let join_iter = (0..0).iter_join_with(1);
        assert_eq!(join_iter.fold(7, |accum, _| accum + 100), 7);
    }

    #[test]
    fn try_fold() {
        let content = [1, 2, 0, 3];
        let mut join_iter = content.iter().iter_join_with(1);

        let result = join_iter.try_fold(0, |accum, next| match next {
            Separator(sep) => Ok(accum + sep),
            Element(el) if *el == 0 => Err(accum),
            Element(el) => Ok(accum + el),
        });

        assert_eq!(result, Err(5));
    }

    // This test exists because implementing `JoinIter::try_fold` in terms of
    // the inner iterator's `try_fold` is non-trivial, and the naive
    // (incorrect) implementation fails this test.
    #[test]
    fn partial_try_fold() {
        let content = [1, 2, 3];
        let mut join_iter = content.iter().iter_join_with(1);

        let _ = join_iter.try_fold(1, |_, next| match next {
            Element(_) => Some(1),
            Separator(_) => None,
        });

        // At this point, the remaining elements in the iterator SHOULD be E(2), S(1), E(3)
        assert_eq!(join_iter.count(), 3);
    }

    #[test]
    fn last_empty() {
        let content: [i32; 0] = [];
        let join_iter = content.iter().iter_join_with(0);

        assert_eq!(join_iter.last(), None);
    }

    #[test]
    fn last_only() {
        let content = [1];
        let join_iter = content.iter().iter_join_with(0);

        assert_eq!(join_iter.last(), Some(Element(&1)));
    }

    #[test]
    fn last_initial() {
        let content = [1, 2, 3];
        let join_iter = content.iter().iter_join_with(0);

        assert_eq!(join_iter.last(), Some(Element(&3)));
    }

    #[test]
    fn last_sep() {
        let content = [1, 2, 3];
        let mut join_iter = content.iter().iter_join_with(0);
        join_iter.next().unwrap();

        assert_eq!(join_iter.last(), Some(Element(&3)));
    }

    #[test]
    fn last_element() {
        let content = [1, 2, 3];
        let mut join_iter = content.iter().iter_join_with(0);
        join_iter.next().unwrap();
        join_iter.next().unwrap();

        assert_eq!(join_iter.last(), Some(Element(&3)));
    }

    #[test]
    fn last_sep_2() {
        let content = [1, 2];
        let mut join_iter = content.iter().iter_join_with(0);
        join_iter.next().unwrap();

        assert_eq!(join_iter.last(), Some(Element(&2)));
    }

    #[test]
    fn last_element_2() {
        let content = [1, 2];
        let mut join_iter = content.iter().iter_join_with(0);
        join_iter.next().unwrap();
        join_iter.next().unwrap();

        assert_eq!(join_iter.last(), Some(Element(&2)));
    }

    #[test]
    fn last_emptied() {
        let content = [1, 2];
        let mut join_iter = content.iter().iter_join_with(0);
        join_iter.next().unwrap();
        join_iter.next().unwrap();
        join_iter.next().unwrap();

        assert_eq!(join_iter.last(), None);
    }
}