More cleanup, yey!
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 """Some useful(?) utilities for dealing with Lists."""
4
5 from collections import Counter
6 from itertools import islice
7 from typing import Any, Iterator, List, Sequence, Tuple
8
9
10 def shard(lst: List[Any], size: int) -> Iterator[Any]:
11     """
12     Yield successive size-sized shards from lst.
13
14     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
15     ...     [_ for _ in sublist]
16     [1, 2, 3]
17     [4, 5, 6]
18     [7, 8, 9]
19     [10, 11, 12]
20
21     """
22     for x in range(0, len(lst), size):
23         yield islice(lst, x, x + size)
24
25
26 def flatten(lst: List[Any]) -> List[Any]:
27     """
28     Flatten out a list:
29
30     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
31     [1, 2, 3, 4, 5, 6, 7, 8, 9]
32
33     """
34     if len(lst) == 0:
35         return lst
36     if isinstance(lst[0], list):
37         return flatten(lst[0]) + flatten(lst[1:])
38     return lst[:1] + flatten(lst[1:])
39
40
41 def prepend(item: Any, lst: List[Any]) -> List[Any]:
42     """
43     Prepend an item to a list.
44
45     >>> prepend('foo', ['bar', 'baz'])
46     ['foo', 'bar', 'baz']
47
48     """
49     lst.insert(0, item)
50     return lst
51
52
53 def remove_list_if_one_element(lst: List[Any]) -> Any:
54     """
55     Remove the list and return the 0th element iff its length is one.
56
57     >>> remove_list_if_one_element([1234])
58     1234
59
60     >>> remove_list_if_one_element([1, 2, 3, 4])
61     [1, 2, 3, 4]
62
63     """
64     if len(lst) == 1:
65         return lst[0]
66     else:
67         return lst
68
69
70 def population_counts(lst: Sequence[Any]) -> Counter:
71     """
72     Return a population count mapping for the list (i.e. the keys are
73     list items and the values are the number of occurrances of that
74     list item in the original list.
75
76     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
77     Counter({1: 3, 3: 3, 2: 2, 4: 1})
78
79     """
80     return Counter(lst)
81
82
83 def most_common(lst: List[Any], *, count=1) -> Any:
84
85     """
86     Return the most common item in the list.  In the case of ties,
87     which most common item is returned will be random.
88
89     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
90     3
91
92     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
93     [3, 1]
94
95     """
96     p = population_counts(lst)
97     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
98
99
100 def least_common(lst: List[Any], *, count=1) -> Any:
101     """
102     Return the least common item in the list.  In the case of
103     ties, which least common item is returned will be random.
104
105     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
106     4
107
108     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
109     [4, 2]
110
111     """
112     p = population_counts(lst)
113     mc = p.most_common()[-count:]
114     mc.reverse()
115     return remove_list_if_one_element([_[0] for _ in mc])
116
117
118 def dedup_list(lst: List[Any]) -> List[Any]:
119     """
120     Remove duplicates from the list performantly.
121
122     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
123     [1, 2, 3, 4, 5]
124
125     """
126     return list(set(lst))
127
128
129 def uniq(lst: List[Any]) -> List[Any]:
130     """
131     Alias for dedup_list.
132     """
133     return dedup_list(lst)
134
135
136 def contains_duplicates(lst: List[Any]) -> bool:
137     """
138     Does the list contian duplicate elements or not?
139
140     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
141     >>> contains_duplicates(lst)
142     True
143
144     >>> contains_duplicates(dedup_list(lst))
145     False
146
147     """
148     seen = set()
149     for _ in lst:
150         if _ in seen:
151             return True
152         seen.add(_)
153     return False
154
155
156 def all_unique(lst: List[Any]) -> bool:
157     """
158     Inverted alias for contains_duplicates.
159     """
160     return not contains_duplicates(lst)
161
162
163 def transpose(lst: List[Any]) -> List[Any]:
164     """
165     Transpose a list of lists.
166
167     >>> lst = [[1, 2], [3, 4], [5, 6]]
168     >>> transpose(lst)
169     [[1, 3, 5], [2, 4, 6]]
170
171     """
172     transposed = zip(*lst)
173     return [list(_) for _ in transposed]
174
175
176 def ngrams(lst: Sequence[Any], n):
177     """
178     Return the ngrams in the sequence.
179
180     >>> seq = 'encyclopedia'
181     >>> for _ in ngrams(seq, 3):
182     ...     _
183     'enc'
184     'ncy'
185     'cyc'
186     'ycl'
187     'clo'
188     'lop'
189     'ope'
190     'ped'
191     'edi'
192     'dia'
193
194     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
195     >>> for _ in ngrams(seq, 3):
196     ...     _
197     ['this', 'is', 'an']
198     ['is', 'an', 'awesome']
199     ['an', 'awesome', 'test']
200     """
201     for i in range(len(lst) - n + 1):
202         yield lst[i : i + n]
203
204
205 def permute(seq: str):
206     """
207     Returns all permutations of a sequence; takes O(N!) time.
208
209     >>> for x in permute('cat'):
210     ...     print(x)
211     cat
212     cta
213     act
214     atc
215     tca
216     tac
217
218     """
219     yield from _permute(seq, "")
220
221
222 def _permute(seq: str, path: str):
223     seq_len = len(seq)
224     if seq_len == 0:
225         yield path
226
227     for i in range(seq_len):
228         car = seq[i]
229         left = seq[0:i]
230         right = seq[i + 1 :]
231         cdr = left + right
232         yield from _permute(cdr, path + car)
233
234
235 def binary_search(lst: Sequence[Any], target: Any, *, sanity_check=False) -> Tuple[bool, int]:
236     """Performs a binary search on lst (which must already be sorted).
237     Returns a Tuple composed of a bool which indicates whether the
238     target was found and an int which indicates the index closest to
239     target whether it was found or not.
240
241     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
242     >>> binary_search(a, 4)
243     (True, 1)
244
245     >>> binary_search(a, 12)
246     (False, 8)
247
248     >>> binary_search(a, 3)
249     (False, 1)
250
251     >>> binary_search(a, 2)
252     (False, 1)
253
254     >>> a.append(9)
255     >>> binary_search(a, 4, sanity_check=True)
256     Traceback (most recent call last):
257     ...
258     AssertionError
259
260     """
261     if sanity_check:
262         last = None
263         for x in lst:
264             if last is not None:
265                 assert x >= last  # This asserts iff the list isn't sorted
266             last = x  # in ascending order.
267     return _binary_search(lst, target, 0, len(lst) - 1)
268
269
270 def _binary_search(lst: Sequence[Any], target: Any, low: int, high: int) -> Tuple[bool, int]:
271     if high >= low:
272         mid = (high + low) // 2
273         if lst[mid] == target:
274             return (True, mid)
275         elif lst[mid] > target:
276             return _binary_search(lst, target, low, mid - 1)
277         else:
278             return _binary_search(lst, target, mid + 1, high)
279     else:
280         return (False, low)
281
282
283 if __name__ == '__main__':
284     import doctest
285
286     doctest.testmod()