7f28ade922a4b733ee29ea0c68bc957ff587c51b
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 """Some useful(?) utilities for dealing with Lists."""
4
5 import random
6 from collections import Counter
7 from itertools import islice
8 from typing import Any, Iterator, List, MutableSequence, Sequence, Tuple
9
10
11 def shard(lst: List[Any], size: int) -> Iterator[Any]:
12     """
13     Yield successive size-sized shards from lst.
14
15     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
16     ...     [_ for _ in sublist]
17     [1, 2, 3]
18     [4, 5, 6]
19     [7, 8, 9]
20     [10, 11, 12]
21
22     """
23     for x in range(0, len(lst), size):
24         yield islice(lst, x, x + size)
25
26
27 def flatten(lst: List[Any]) -> List[Any]:
28     """
29     Flatten out a list:
30
31     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
32     [1, 2, 3, 4, 5, 6, 7, 8, 9]
33
34     """
35     if len(lst) == 0:
36         return lst
37     if isinstance(lst[0], list):
38         return flatten(lst[0]) + flatten(lst[1:])
39     return lst[:1] + flatten(lst[1:])
40
41
42 def prepend(item: Any, lst: List[Any]) -> List[Any]:
43     """
44     Prepend an item to a list.
45
46     >>> prepend('foo', ['bar', 'baz'])
47     ['foo', 'bar', 'baz']
48
49     """
50     lst.insert(0, item)
51     return lst
52
53
54 def remove_list_if_one_element(lst: List[Any]) -> Any:
55     """
56     Remove the list and return the 0th element iff its length is one.
57
58     >>> remove_list_if_one_element([1234])
59     1234
60
61     >>> remove_list_if_one_element([1, 2, 3, 4])
62     [1, 2, 3, 4]
63
64     """
65     if len(lst) == 1:
66         return lst[0]
67     else:
68         return lst
69
70
71 def population_counts(lst: Sequence[Any]) -> Counter:
72     """
73     Return a population count mapping for the list (i.e. the keys are
74     list items and the values are the number of occurrances of that
75     list item in the original list.
76
77     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
78     Counter({1: 3, 3: 3, 2: 2, 4: 1})
79
80     """
81     return Counter(lst)
82
83
84 def most_common(lst: List[Any], *, count=1) -> Any:
85
86     """
87     Return the most common item in the list.  In the case of ties,
88     which most common item is returned will be random.
89
90     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
91     3
92
93     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
94     [3, 1]
95
96     """
97     p = population_counts(lst)
98     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
99
100
101 def least_common(lst: List[Any], *, count=1) -> Any:
102     """
103     Return the least common item in the list.  In the case of
104     ties, which least common item is returned will be random.
105
106     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
107     4
108
109     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
110     [4, 2]
111
112     """
113     p = population_counts(lst)
114     mc = p.most_common()[-count:]
115     mc.reverse()
116     return remove_list_if_one_element([_[0] for _ in mc])
117
118
119 def dedup_list(lst: List[Any]) -> List[Any]:
120     """
121     Remove duplicates from the list performantly.
122
123     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
124     [1, 2, 3, 4, 5]
125
126     """
127     return list(set(lst))
128
129
130 def uniq(lst: List[Any]) -> List[Any]:
131     """
132     Alias for dedup_list.
133     """
134     return dedup_list(lst)
135
136
137 def contains_duplicates(lst: List[Any]) -> bool:
138     """
139     Does the list contian duplicate elements or not?
140
141     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
142     >>> contains_duplicates(lst)
143     True
144
145     >>> contains_duplicates(dedup_list(lst))
146     False
147
148     """
149     seen = set()
150     for _ in lst:
151         if _ in seen:
152             return True
153         seen.add(_)
154     return False
155
156
157 def all_unique(lst: List[Any]) -> bool:
158     """
159     Inverted alias for contains_duplicates.
160     """
161     return not contains_duplicates(lst)
162
163
164 def transpose(lst: List[Any]) -> List[Any]:
165     """
166     Transpose a list of lists.
167
168     >>> lst = [[1, 2], [3, 4], [5, 6]]
169     >>> transpose(lst)
170     [[1, 3, 5], [2, 4, 6]]
171
172     """
173     transposed = zip(*lst)
174     return [list(_) for _ in transposed]
175
176
177 def ngrams(lst: Sequence[Any], n):
178     """
179     Return the ngrams in the sequence.
180
181     >>> seq = 'encyclopedia'
182     >>> for _ in ngrams(seq, 3):
183     ...     _
184     'enc'
185     'ncy'
186     'cyc'
187     'ycl'
188     'clo'
189     'lop'
190     'ope'
191     'ped'
192     'edi'
193     'dia'
194
195     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
196     >>> for _ in ngrams(seq, 3):
197     ...     _
198     ['this', 'is', 'an']
199     ['is', 'an', 'awesome']
200     ['an', 'awesome', 'test']
201     """
202     for i in range(len(lst) - n + 1):
203         yield lst[i : i + n]
204
205
206 def permute(seq: str):
207     """
208     Returns all permutations of a sequence; takes O(N!) time.
209
210     >>> for x in permute('cat'):
211     ...     print(x)
212     cat
213     cta
214     act
215     atc
216     tca
217     tac
218
219     """
220     yield from _permute(seq, "")
221
222
223 def _permute(seq: str, path: str):
224     seq_len = len(seq)
225     if seq_len == 0:
226         yield path
227
228     for i in range(seq_len):
229         car = seq[i]
230         left = seq[0:i]
231         right = seq[i + 1 :]
232         cdr = left + right
233         yield from _permute(cdr, path + car)
234
235
236 def shuffle(seq: MutableSequence[Any]) -> MutableSequence[Any]:
237     """Shuffles a sequence into a random order.
238
239     >>> random.seed(22)
240     >>> shuffle([1, 2, 3, 4, 5])
241     [3, 4, 1, 5, 2]
242
243     >>> shuffle('example')
244     'empaelx'
245
246     """
247     if isinstance(seq, str):
248         import string_utils
249
250         return string_utils.shuffle(seq)
251     else:
252         random.shuffle(seq)
253         return seq
254
255
256 def scramble(seq: MutableSequence[Any]) -> MutableSequence[Any]:
257     return shuffle(seq)
258
259
260 def binary_search(lst: Sequence[Any], target: Any, *, sanity_check=False) -> Tuple[bool, int]:
261     """Performs a binary search on lst (which must already be sorted).
262     Returns a Tuple composed of a bool which indicates whether the
263     target was found and an int which indicates the index closest to
264     target whether it was found or not.
265
266     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
267     >>> binary_search(a, 4)
268     (True, 1)
269
270     >>> binary_search(a, 12)
271     (False, 8)
272
273     >>> binary_search(a, 3)
274     (False, 1)
275
276     >>> binary_search(a, 2)
277     (False, 1)
278
279     >>> a.append(9)
280     >>> binary_search(a, 4, sanity_check=True)
281     Traceback (most recent call last):
282     ...
283     AssertionError
284
285     """
286     if sanity_check:
287         last = None
288         for x in lst:
289             if last is not None:
290                 assert x >= last  # This asserts iff the list isn't sorted
291             last = x  # in ascending order.
292     return _binary_search(lst, target, 0, len(lst) - 1)
293
294
295 def _binary_search(lst: Sequence[Any], target: Any, low: int, high: int) -> Tuple[bool, int]:
296     if high >= low:
297         mid = (high + low) // 2
298         if lst[mid] == target:
299             return (True, mid)
300         elif lst[mid] > target:
301             return _binary_search(lst, target, low, mid - 1)
302         else:
303             return _binary_search(lst, target, mid + 1, high)
304     else:
305         return (False, low)
306
307
308 if __name__ == '__main__':
309     import doctest
310
311     doctest.testmod()