d70159a1b2dadb61640eae20f029608cabd2f46e
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 from collections import Counter
4 from itertools import islice
5 from typing import Any, Iterator, List, Mapping, Sequence, Tuple
6
7
8 def shard(lst: List[Any], size: int) -> Iterator[Any]:
9     """
10     Yield successive size-sized shards from lst.
11
12     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
13     ...     [_ for _ in sublist]
14     [1, 2, 3]
15     [4, 5, 6]
16     [7, 8, 9]
17     [10, 11, 12]
18
19     """
20     for x in range(0, len(lst), size):
21         yield islice(lst, x, x + size)
22
23
24 def flatten(lst: List[Any]) -> List[Any]:
25     """
26     Flatten out a list:
27
28     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
29     [1, 2, 3, 4, 5, 6, 7, 8, 9]
30
31     """
32     if len(lst) == 0:
33         return lst
34     if isinstance(lst[0], list):
35         return flatten(lst[0]) + flatten(lst[1:])
36     return lst[:1] + flatten(lst[1:])
37
38
39 def prepend(item: Any, lst: List[Any]) -> List[Any]:
40     """
41     Prepend an item to a list.
42
43     >>> prepend('foo', ['bar', 'baz'])
44     ['foo', 'bar', 'baz']
45
46     """
47     lst.insert(0, item)
48     return lst
49
50
51 def remove_list_if_one_element(lst: List[Any]) -> Any:
52     """
53     Remove the list and return the 0th element iff its length is one.
54
55     >>> remove_list_if_one_element([1234])
56     1234
57
58     >>> remove_list_if_one_element([1, 2, 3, 4])
59     [1, 2, 3, 4]
60
61     """
62     if len(lst) == 1:
63         return lst[0]
64     else:
65         return lst
66
67
68 def population_counts(lst: Sequence[Any]) -> Counter:
69     """
70     Return a population count mapping for the list (i.e. the keys are
71     list items and the values are the number of occurrances of that
72     list item in the original list.
73
74     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
75     Counter({1: 3, 3: 3, 2: 2, 4: 1})
76
77     """
78     return Counter(lst)
79
80
81 def most_common(lst: List[Any], *, count=1) -> Any:
82
83     """
84     Return the most common item in the list.  In the case of ties,
85     which most common item is returned will be random.
86
87     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
88     3
89
90     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
91     [3, 1]
92
93     """
94     p = population_counts(lst)
95     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
96
97
98 def least_common(lst: List[Any], *, count=1) -> Any:
99     """
100     Return the least common item in the list.  In the case of
101     ties, which least common item is returned will be random.
102
103     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
104     4
105
106     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
107     [4, 2]
108
109     """
110     p = population_counts(lst)
111     mc = p.most_common()[-count:]
112     mc.reverse()
113     return remove_list_if_one_element([_[0] for _ in mc])
114
115
116 def dedup_list(lst: List[Any]) -> List[Any]:
117     """
118     Remove duplicates from the list performantly.
119
120     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
121     [1, 2, 3, 4, 5]
122
123     """
124     return list(set(lst))
125
126
127 def uniq(lst: List[Any]) -> List[Any]:
128     """
129     Alias for dedup_list.
130     """
131     return dedup_list(lst)
132
133
134 def contains_duplicates(lst: List[Any]) -> bool:
135     """
136     Does the list contian duplicate elements or not?
137
138     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
139     >>> contains_duplicates(lst)
140     True
141
142     >>> contains_duplicates(dedup_list(lst))
143     False
144
145     """
146     seen = set()
147     for _ in lst:
148         if _ in seen:
149             return True
150         seen.add(_)
151     return False
152
153
154 def all_unique(lst: List[Any]) -> bool:
155     """
156     Inverted alias for contains_duplicates.
157     """
158     return not contains_duplicates(lst)
159
160
161 def transpose(lst: List[Any]) -> List[Any]:
162     """
163     Transpose a list of lists.
164
165     >>> lst = [[1, 2], [3, 4], [5, 6]]
166     >>> transpose(lst)
167     [[1, 3, 5], [2, 4, 6]]
168
169     """
170     transposed = zip(*lst)
171     return [list(_) for _ in transposed]
172
173
174 def ngrams(lst: Sequence[Any], n):
175     """
176     Return the ngrams in the sequence.
177
178     >>> seq = 'encyclopedia'
179     >>> for _ in ngrams(seq, 3):
180     ...     _
181     'enc'
182     'ncy'
183     'cyc'
184     'ycl'
185     'clo'
186     'lop'
187     'ope'
188     'ped'
189     'edi'
190     'dia'
191
192     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
193     >>> for _ in ngrams(seq, 3):
194     ...     _
195     ['this', 'is', 'an']
196     ['is', 'an', 'awesome']
197     ['an', 'awesome', 'test']
198     """
199     for i in range(len(lst) - n + 1):
200         yield lst[i : i + n]
201
202
203 def permute(seq: str):
204     """
205     Returns all permutations of a sequence; takes O(N!) time.
206
207     >>> for x in permute('cat'):
208     ...     print(x)
209     cat
210     cta
211     act
212     atc
213     tca
214     tac
215
216     """
217     yield from _permute(seq, "")
218
219
220 def _permute(seq: str, path: str):
221     seq_len = len(seq)
222     if seq_len == 0:
223         yield path
224
225     for i in range(seq_len):
226         car = seq[i]
227         left = seq[0:i]
228         right = seq[i + 1 :]
229         cdr = left + right
230         yield from _permute(cdr, path + car)
231
232
233 def binary_search(
234     lst: Sequence[Any], target: Any, *, sanity_check=False
235 ) -> Tuple[bool, int]:
236     """Performs a binary search on lst (which must already be sorted).
237     Returns a Tuple composed of a bool which indicates whether the
238     target was found and an int which indicates the index closest to
239     target whether it was found or not.
240
241     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
242     >>> binary_search(a, 4)
243     (True, 1)
244
245     >>> binary_search(a, 12)
246     (False, 8)
247
248     >>> binary_search(a, 3)
249     (False, 1)
250
251     >>> binary_search(a, 2)
252     (False, 1)
253
254     >>> a.append(9)
255     >>> binary_search(a, 4, sanity_check=True)
256     Traceback (most recent call last):
257     ...
258     AssertionError
259
260     """
261     if sanity_check:
262         last = None
263         for x in lst:
264             if last is not None:
265                 assert x >= last  # This asserts iff the list isn't sorted
266             last = x  # in ascending order.
267     return _binary_search(lst, target, 0, len(lst) - 1)
268
269
270 def _binary_search(
271     lst: Sequence[Any], target: Any, low: int, high: int
272 ) -> Tuple[bool, int]:
273     if high >= low:
274         mid = (high + low) // 2
275         if lst[mid] == target:
276             return (True, mid)
277         elif lst[mid] > target:
278             return _binary_search(lst, target, low, mid - 1)
279         else:
280             return _binary_search(lst, target, mid + 1, high)
281     else:
282         return (False, low)
283
284
285 if __name__ == '__main__':
286     import doctest
287
288     doctest.testmod()