A bunch of changes...
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 from collections import Counter
4 from itertools import islice
5 from typing import Any, Iterator, List, Mapping, Sequence, Tuple
6
7
8 def shard(lst: List[Any], size: int) -> Iterator[Any]:
9     """
10     Yield successive size-sized shards from lst.
11
12     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
13     ...     [_ for _ in sublist]
14     [1, 2, 3]
15     [4, 5, 6]
16     [7, 8, 9]
17     [10, 11, 12]
18
19     """
20     for x in range(0, len(lst), size):
21         yield islice(lst, x, x + size)
22
23
24 def flatten(lst: List[Any]) -> List[Any]:
25     """
26     Flatten out a list:
27
28     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
29     [1, 2, 3, 4, 5, 6, 7, 8, 9]
30
31     """
32     if len(lst) == 0:
33         return lst
34     if isinstance(lst[0], list):
35         return flatten(lst[0]) + flatten(lst[1:])
36     return lst[:1] + flatten(lst[1:])
37
38
39 def prepend(item: Any, lst: List[Any]) -> List[Any]:
40     """
41     Prepend an item to a list.
42
43     >>> prepend('foo', ['bar', 'baz'])
44     ['foo', 'bar', 'baz']
45
46     """
47     lst.insert(0, item)
48     return lst
49
50
51 def remove_list_if_one_element(lst: List[Any]) -> Any:
52     """
53     Remove the list and return the 0th element iff its length is one.
54
55     >>> remove_list_if_one_element([1234])
56     1234
57
58     >>> remove_list_if_one_element([1, 2, 3, 4])
59     [1, 2, 3, 4]
60
61     """
62     if len(lst) == 1:
63         return lst[0]
64     else:
65         return lst
66
67
68 def population_counts(lst: List[Any]) -> Mapping[Any, int]:
69     """
70     Return a population count mapping for the list (i.e. the keys are
71     list items and the values are the number of occurrances of that
72     list item in the original list.
73
74     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
75     Counter({1: 3, 3: 3, 2: 2, 4: 1})
76
77     """
78     return Counter(lst)
79
80
81 def most_common(lst: List[Any], *, count=1) -> Any:
82
83     """
84     Return the most common item in the list.  In the case of ties,
85     which most common item is returned will be random.
86
87     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
88     3
89
90     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
91     [3, 1]
92
93     """
94     p = population_counts(lst)
95     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
96
97
98 def least_common(lst: List[Any], *, count=1) -> Any:
99     """
100     Return the least common item in the list.  In the case of
101     ties, which least common item is returned will be random.
102
103     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
104     4
105
106     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
107     [4, 2]
108
109     """
110     p = population_counts(lst)
111     mc = p.most_common()[-count:]
112     mc.reverse()
113     return remove_list_if_one_element([_[0] for _ in mc])
114
115
116 def dedup_list(lst: List[Any]) -> List[Any]:
117     """
118     Remove duplicates from the list performantly.
119
120     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
121     [1, 2, 3, 4, 5]
122
123     """
124     return list(set(lst))
125
126
127 def uniq(lst: List[Any]) -> List[Any]:
128     """
129     Alias for dedup_list.
130     """
131     return dedup_list(lst)
132
133
134 def contains_duplicates(lst: List[Any]) -> bool:
135     """
136     Does the list contian duplicate elements or not?
137
138     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
139     >>> contains_duplicates(lst)
140     True
141
142     >>> contains_duplicates(dedup_list(lst))
143     False
144
145     """
146     seen = set()
147     for _ in lst:
148         if _ in seen:
149             return True
150         seen.add(_)
151     return False
152
153
154 def all_unique(lst: List[Any]) -> bool:
155     """
156     Inverted alias for contains_duplicates.
157     """
158     return not contains_duplicates(lst)
159
160
161 def transpose(lst: List[Any]) -> List[Any]:
162     """
163     Transpose a list of lists.
164
165     >>> lst = [[1, 2], [3, 4], [5, 6]]
166     >>> transpose(lst)
167     [[1, 3, 5], [2, 4, 6]]
168
169     """
170     transposed = zip(*lst)
171     return [list(_) for _ in transposed]
172
173
174 def ngrams(lst: Sequence[Any], n):
175     """
176     Return the ngrams in the sequence.
177
178     >>> seq = 'encyclopedia'
179     >>> for _ in ngrams(seq, 3):
180     ...     _
181     'enc'
182     'ncy'
183     'cyc'
184     'ycl'
185     'clo'
186     'lop'
187     'ope'
188     'ped'
189     'edi'
190     'dia'
191
192     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
193     >>> for _ in ngrams(seq, 3):
194     ...     _
195     ['this', 'is', 'an']
196     ['is', 'an', 'awesome']
197     ['an', 'awesome', 'test']
198     """
199     for i in range(len(lst) - n + 1):
200         yield lst[i:i + n]
201
202
203 def permute(seq: Sequence[Any]):
204     """
205     Returns all permutations of a sequence; takes O(N^2) time.
206
207     >>> for x in permute('cat'):
208     ...     print(x)
209     cat
210     cta
211     act
212     atc
213     tca
214     tac
215
216     """
217     yield from _permute(seq, "")
218
219 def _permute(seq: Sequence[Any], path):
220     if len(seq) == 0:
221         yield path
222
223     for i in range(len(seq)):
224         car = seq[i]
225         left = seq[0:i]
226         right = seq[i + 1:]
227         cdr = left + right
228         yield from _permute(cdr, path + car)
229
230
231 def binary_search(lst: Sequence[Any], target:Any) -> Tuple[bool, int]:
232     """Performs a binary search on lst (which must already be sorted).
233     Returns a Tuple composed of a bool which indicates whether the
234     target was found and an int which indicates the index closest to
235     target whether it was found or not.
236
237     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
238     >>> binary_search(a, 4)
239     (True, 1)
240
241     >>> binary_search(a, 12)
242     (False, 8)
243
244     >>> binary_search(a, 3)
245     (False, 1)
246
247     >>> binary_search(a, 2)
248     (False, 1)
249
250     """
251     return _binary_search(lst, target, 0, len(lst) - 1)
252
253
254 def _binary_search(lst: Sequence[Any], target: Any, low: int, high: int) -> Tuple[bool, int]:
255     if high >= low:
256         mid = (high + low) // 2
257         if lst[mid] == target:
258             return (True, mid)
259         elif lst[mid] > target:
260             return _binary_search(lst, target, low, mid - 1)
261         else:
262             return _binary_search(lst, target, mid + 1, high)
263     else:
264         return (False, low)
265
266
267 if __name__ == '__main__':
268     import doctest
269     doctest.testmod()