More spring cleaning.
[pyutils.git] / src / pyutils / iter_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A collection of :class:`Iterator` subclasses that can be composed
6 with another iterator and provide extra functionality:
7
8     + :class:`PeekingIterator`
9     + :class:`PushbackIterator`
10     + :class:`SamplingIterator`
11
12 """
13
14 import random
15 from collections.abc import Iterator
16 from typing import Any, List, Optional
17
18
19 class PeekingIterator(Iterator):
20     """An iterator that lets you :meth:`peek` at the next item on deck.
21     Returns None when there is no next item (i.e. when
22     :meth:`__next__` will produce a `StopIteration` exception).
23
24     >>> p = PeekingIterator(iter(range(3)))
25     >>> p.__next__()
26     0
27     >>> p.peek()
28     1
29     >>> p.peek()
30     1
31     >>> p.__next__()
32     1
33     >>> p.__next__()
34     2
35     >>> p.peek() == None
36     True
37     >>> p.__next__()
38     Traceback (most recent call last):
39       ...
40     StopIteration
41     """
42
43     def __init__(self, source_iter: Iterator):
44         """
45         Args:
46             source_iter: the iterator we want to peek at
47         """
48         self.source_iter = source_iter
49         self.on_deck: List[Any] = []
50
51     def __iter__(self) -> Iterator:
52         return self
53
54     def __next__(self) -> Any:
55         if len(self.on_deck) > 0:
56             return self.on_deck.pop()
57         else:
58             item = self.source_iter.__next__()
59             return item
60
61     def peek(self) -> Optional[Any]:
62         """Peek at the upcoming value on the top of our contained
63         :py:class:`Iterator` non-destructively (i.e. calling :meth:`__next__` will
64         still produce the peeked value).
65
66         Returns:
67             The value that will be produced by the contained iterator next
68             or None if the contained Iterator is exhausted and will raise
69             `StopIteration` when read.
70
71         """
72         if len(self.on_deck) > 0:
73             return self.on_deck[0]
74         try:
75             item = next(self.source_iter)
76             self.on_deck.append(item)
77             return self.peek()
78         except StopIteration:
79             return None
80
81
82 class PushbackIterator(Iterator):
83     """An iterator that allows you to push items back onto the front
84     of the sequence so that they are produced before the items at the
85     front/top of the contained py:class:`Iterator`. e.g.
86
87     >>> i = PushbackIterator(iter(range(3)))
88     >>> i.__next__()
89     0
90     >>> i.push_back(99)
91     >>> i.push_back(98)
92     >>> i.__next__()
93     98
94     >>> i.__next__()
95     99
96     >>> i.__next__()
97     1
98     >>> i.__next__()
99     2
100     >>> i.push_back(100)
101     >>> i.__next__()
102     100
103     >>> i.__next__()
104     Traceback (most recent call last):
105       ...
106     StopIteration
107
108     """
109
110     def __init__(self, source_iter: Iterator):
111         self.source_iter = source_iter
112         self.pushed_back: List[Any] = []
113
114     def __iter__(self) -> Iterator:
115         return self
116
117     def __next__(self) -> Any:
118         if len(self.pushed_back) > 0:
119             return self.pushed_back.pop()
120         return self.source_iter.__next__()
121
122     def push_back(self, item: Any) -> None:
123         """Push an item onto the top of the contained iterator such that
124         the next time :meth:`__next__` is invoked we produce that item.
125
126         Args:
127             item: the item to produce from :meth:`__next__` next.
128         """
129         self.pushed_back.append(item)
130
131
132 class SamplingIterator(Iterator):
133     """An :py:class:`Iterator` that simply echoes what its
134     `source_iter` produces but also collects a random sample (of size
135     `sample_size`) from the stream that can be queried at any time.
136
137     .. note::
138         Until `sample_size` elements have been produced by the
139         `source_iter`, the sample return will be less than `sample_size`
140         elements in length.
141
142     .. note::
143         If `sample_size` is >= `len(source_iter)` then this will produce
144         a copy of `source_iter`.
145
146     >>> import collections
147     >>> import random
148
149     >>> random.seed(22)
150     >>> s = SamplingIterator(iter(range(100)), 10)
151     >>> s.__next__()
152     0
153
154     >>> s.__next__()
155     1
156
157     >>> s.get_sample()
158     [0, 1]
159
160     >>> collections.deque(s)
161     deque([2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
162
163     >>> s.get_sample()
164     [78, 18, 47, 83, 93, 26, 25, 73, 94, 38]
165
166     """
167
168     def __init__(self, source_iter: Iterator, sample_size: int):
169         self.source_iter = source_iter
170         self.sample_size = sample_size
171         self.resovoir: List[Any] = []
172         self.stream_length_so_far = 0
173
174     def __iter__(self) -> Iterator:
175         return self
176
177     def __next__(self) -> Any:
178         item = self.source_iter.__next__()
179         self.stream_length_so_far += 1
180
181         # Filling the resovoir
182         pop = len(self.resovoir)
183         if pop < self.sample_size:
184             self.resovoir.append(item)
185             if self.sample_size == (pop + 1):  # just finished filling...
186                 random.shuffle(self.resovoir)
187
188         # Swap this item for one in the resovoir with probabilty
189         # sample_size / stream_length_so_far.  See:
190         #
191         # https://en.wikipedia.org/wiki/Reservoir_sampling
192         else:
193             r = random.randint(0, self.stream_length_so_far)
194             if r < self.sample_size:
195                 self.resovoir[r] = item
196         return item
197
198     def get_sample(self) -> List[Any]:
199         """
200         Returns:
201             The current sample set populated randomly from the items
202             returned by the contained :class:`Iterator` so far.
203
204         .. note::
205             Until `sample_size` elements have been produced by the
206             `source_iter`, the sample return will be less than `sample_size`
207             elements in length.
208
209         .. note::
210             If `sample_size` is >= `len(source_iter)` then this will produce
211             a copy of `source_iter`.
212         """
213         return self.resovoir
214
215
216 if __name__ == "__main__":
217     import doctest
218
219     doctest.testmod()