Make sanity check optional.
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 from collections import Counter
4 from itertools import islice
5 from typing import Any, Iterator, List, Mapping, Sequence, Tuple
6
7
8 def shard(lst: List[Any], size: int) -> Iterator[Any]:
9     """
10     Yield successive size-sized shards from lst.
11
12     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
13     ...     [_ for _ in sublist]
14     [1, 2, 3]
15     [4, 5, 6]
16     [7, 8, 9]
17     [10, 11, 12]
18
19     """
20     for x in range(0, len(lst), size):
21         yield islice(lst, x, x + size)
22
23
24 def flatten(lst: List[Any]) -> List[Any]:
25     """
26     Flatten out a list:
27
28     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
29     [1, 2, 3, 4, 5, 6, 7, 8, 9]
30
31     """
32     if len(lst) == 0:
33         return lst
34     if isinstance(lst[0], list):
35         return flatten(lst[0]) + flatten(lst[1:])
36     return lst[:1] + flatten(lst[1:])
37
38
39 def prepend(item: Any, lst: List[Any]) -> List[Any]:
40     """
41     Prepend an item to a list.
42
43     >>> prepend('foo', ['bar', 'baz'])
44     ['foo', 'bar', 'baz']
45
46     """
47     lst.insert(0, item)
48     return lst
49
50
51 def remove_list_if_one_element(lst: List[Any]) -> Any:
52     """
53     Remove the list and return the 0th element iff its length is one.
54
55     >>> remove_list_if_one_element([1234])
56     1234
57
58     >>> remove_list_if_one_element([1, 2, 3, 4])
59     [1, 2, 3, 4]
60
61     """
62     if len(lst) == 1:
63         return lst[0]
64     else:
65         return lst
66
67
68 def population_counts(lst: List[Any]) -> Mapping[Any, int]:
69     """
70     Return a population count mapping for the list (i.e. the keys are
71     list items and the values are the number of occurrances of that
72     list item in the original list.
73
74     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
75     Counter({1: 3, 3: 3, 2: 2, 4: 1})
76
77     """
78     return Counter(lst)
79
80
81 def most_common(lst: List[Any], *, count=1) -> Any:
82
83     """
84     Return the most common item in the list.  In the case of ties,
85     which most common item is returned will be random.
86
87     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
88     3
89
90     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
91     [3, 1]
92
93     """
94     p = population_counts(lst)
95     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
96
97
98 def least_common(lst: List[Any], *, count=1) -> Any:
99     """
100     Return the least common item in the list.  In the case of
101     ties, which least common item is returned will be random.
102
103     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
104     4
105
106     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
107     [4, 2]
108
109     """
110     p = population_counts(lst)
111     mc = p.most_common()[-count:]
112     mc.reverse()
113     return remove_list_if_one_element([_[0] for _ in mc])
114
115
116 def dedup_list(lst: List[Any]) -> List[Any]:
117     """
118     Remove duplicates from the list performantly.
119
120     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
121     [1, 2, 3, 4, 5]
122
123     """
124     return list(set(lst))
125
126
127 def uniq(lst: List[Any]) -> List[Any]:
128     """
129     Alias for dedup_list.
130     """
131     return dedup_list(lst)
132
133
134 def contains_duplicates(lst: List[Any]) -> bool:
135     """
136     Does the list contian duplicate elements or not?
137
138     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
139     >>> contains_duplicates(lst)
140     True
141
142     >>> contains_duplicates(dedup_list(lst))
143     False
144
145     """
146     seen = set()
147     for _ in lst:
148         if _ in seen:
149             return True
150         seen.add(_)
151     return False
152
153
154 def all_unique(lst: List[Any]) -> bool:
155     """
156     Inverted alias for contains_duplicates.
157     """
158     return not contains_duplicates(lst)
159
160
161 def transpose(lst: List[Any]) -> List[Any]:
162     """
163     Transpose a list of lists.
164
165     >>> lst = [[1, 2], [3, 4], [5, 6]]
166     >>> transpose(lst)
167     [[1, 3, 5], [2, 4, 6]]
168
169     """
170     transposed = zip(*lst)
171     return [list(_) for _ in transposed]
172
173
174 def ngrams(lst: Sequence[Any], n):
175     """
176     Return the ngrams in the sequence.
177
178     >>> seq = 'encyclopedia'
179     >>> for _ in ngrams(seq, 3):
180     ...     _
181     'enc'
182     'ncy'
183     'cyc'
184     'ycl'
185     'clo'
186     'lop'
187     'ope'
188     'ped'
189     'edi'
190     'dia'
191
192     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
193     >>> for _ in ngrams(seq, 3):
194     ...     _
195     ['this', 'is', 'an']
196     ['is', 'an', 'awesome']
197     ['an', 'awesome', 'test']
198     """
199     for i in range(len(lst) - n + 1):
200         yield lst[i : i + n]
201
202
203 def permute(seq: Sequence[Any]):
204     """
205     Returns all permutations of a sequence; takes O(N^2) time.
206
207     >>> for x in permute('cat'):
208     ...     print(x)
209     cat
210     cta
211     act
212     atc
213     tca
214     tac
215
216     """
217     yield from _permute(seq, "")
218
219
220 def _permute(seq: Sequence[Any], path):
221     if len(seq) == 0:
222         yield path
223
224     for i in range(len(seq)):
225         car = seq[i]
226         left = seq[0:i]
227         right = seq[i + 1 :]
228         cdr = left + right
229         yield from _permute(cdr, path + car)
230
231
232 def binary_search(lst: Sequence[Any], target: Any, *, sanity_check=False) -> Tuple[bool, int]:
233     """Performs a binary search on lst (which must already be sorted).
234     Returns a Tuple composed of a bool which indicates whether the
235     target was found and an int which indicates the index closest to
236     target whether it was found or not.
237
238     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
239     >>> binary_search(a, 4)
240     (True, 1)
241
242     >>> binary_search(a, 12)
243     (False, 8)
244
245     >>> binary_search(a, 3)
246     (False, 1)
247
248     >>> binary_search(a, 2)
249     (False, 1)
250
251     >>> a.append(9)
252     >>> binary_search(a, 4, sanity_check=True)
253     Traceback (most recent call last):
254     ...
255     AssertionError
256
257     """
258     if sanity_check:
259         last = None
260         for x in lst:
261             if last is not None:
262                 assert x >= last  # This asserts iff the list isn't sorted
263             last = x  # in ascending order.
264     return _binary_search(lst, target, 0, len(lst) - 1)
265
266
267 def _binary_search(
268     lst: Sequence[Any], target: Any, low: int, high: int
269 ) -> Tuple[bool, int]:
270     if high >= low:
271         mid = (high + low) // 2
272         if lst[mid] == target:
273             return (True, mid)
274         elif lst[mid] > target:
275             return _binary_search(lst, target, low, mid - 1)
276         else:
277             return _binary_search(lst, target, mid + 1, high)
278     else:
279         return (False, low)
280
281
282 if __name__ == '__main__':
283     import doctest
284
285     doctest.testmod()