Easier and more self documenting patterns for loading/saving Persistent
[python_utils.git] / list_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Some useful(?) utilities for dealing with Lists."""
6
7 import random
8 from collections import Counter
9 from itertools import chain, combinations, islice
10 from typing import Any, Iterator, List, MutableSequence, Sequence, Tuple
11
12
13 def shard(lst: List[Any], size: int) -> Iterator[Any]:
14     """
15     Yield successive size-sized shards from lst.
16
17     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
18     ...     [_ for _ in sublist]
19     [1, 2, 3]
20     [4, 5, 6]
21     [7, 8, 9]
22     [10, 11, 12]
23
24     """
25     for x in range(0, len(lst), size):
26         yield islice(lst, x, x + size)
27
28
29 def flatten(lst: List[Any]) -> List[Any]:
30     """
31     Flatten out a list:
32
33     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
34     [1, 2, 3, 4, 5, 6, 7, 8, 9]
35
36     """
37     if len(lst) == 0:
38         return lst
39     if isinstance(lst[0], list):
40         return flatten(lst[0]) + flatten(lst[1:])
41     return lst[:1] + flatten(lst[1:])
42
43
44 def prepend(item: Any, lst: List[Any]) -> List[Any]:
45     """
46     Prepend an item to a list.
47
48     >>> prepend('foo', ['bar', 'baz'])
49     ['foo', 'bar', 'baz']
50
51     """
52     lst.insert(0, item)
53     return lst
54
55
56 def remove_list_if_one_element(lst: List[Any]) -> Any:
57     """
58     Remove the list and return the 0th element iff its length is one.
59
60     >>> remove_list_if_one_element([1234])
61     1234
62
63     >>> remove_list_if_one_element([1, 2, 3, 4])
64     [1, 2, 3, 4]
65
66     """
67     if len(lst) == 1:
68         return lst[0]
69     else:
70         return lst
71
72
73 def population_counts(lst: Sequence[Any]) -> Counter:
74     """
75     Return a population count mapping for the list (i.e. the keys are
76     list items and the values are the number of occurrances of that
77     list item in the original list.
78
79     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
80     Counter({1: 3, 3: 3, 2: 2, 4: 1})
81
82     """
83     return Counter(lst)
84
85
86 def most_common(lst: List[Any], *, count=1) -> Any:
87
88     """
89     Return the most common item in the list.  In the case of ties,
90     which most common item is returned will be random.
91
92     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
93     3
94
95     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
96     [3, 1]
97
98     """
99     p = population_counts(lst)
100     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
101
102
103 def least_common(lst: List[Any], *, count=1) -> Any:
104     """
105     Return the least common item in the list.  In the case of
106     ties, which least common item is returned will be random.
107
108     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
109     4
110
111     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
112     [4, 2]
113
114     """
115     p = population_counts(lst)
116     mc = p.most_common()[-count:]
117     mc.reverse()
118     return remove_list_if_one_element([_[0] for _ in mc])
119
120
121 def dedup_list(lst: List[Any]) -> List[Any]:
122     """
123     Remove duplicates from the list performantly.
124
125     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
126     [1, 2, 3, 4, 5]
127
128     """
129     return list(set(lst))
130
131
132 def uniq(lst: List[Any]) -> List[Any]:
133     """
134     Alias for dedup_list.
135     """
136     return dedup_list(lst)
137
138
139 def contains_duplicates(lst: List[Any]) -> bool:
140     """
141     Does the list contian duplicate elements or not?
142
143     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
144     >>> contains_duplicates(lst)
145     True
146
147     >>> contains_duplicates(dedup_list(lst))
148     False
149
150     """
151     seen = set()
152     for _ in lst:
153         if _ in seen:
154             return True
155         seen.add(_)
156     return False
157
158
159 def all_unique(lst: List[Any]) -> bool:
160     """
161     Inverted alias for contains_duplicates.
162     """
163     return not contains_duplicates(lst)
164
165
166 def transpose(lst: List[Any]) -> List[Any]:
167     """
168     Transpose a list of lists.
169
170     >>> lst = [[1, 2], [3, 4], [5, 6]]
171     >>> transpose(lst)
172     [[1, 3, 5], [2, 4, 6]]
173
174     """
175     transposed = zip(*lst)
176     return [list(_) for _ in transposed]
177
178
179 def ngrams(lst: Sequence[Any], n):
180     """
181     Return the ngrams in the sequence.
182
183     >>> seq = 'encyclopedia'
184     >>> for _ in ngrams(seq, 3):
185     ...     _
186     'enc'
187     'ncy'
188     'cyc'
189     'ycl'
190     'clo'
191     'lop'
192     'ope'
193     'ped'
194     'edi'
195     'dia'
196
197     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
198     >>> for _ in ngrams(seq, 3):
199     ...     _
200     ['this', 'is', 'an']
201     ['is', 'an', 'awesome']
202     ['an', 'awesome', 'test']
203     """
204     for i in range(len(lst) - n + 1):
205         yield lst[i : i + n]
206
207
208 def permute(seq: str):
209     """
210     Returns all permutations of a sequence; takes O(N!) time.
211
212     >>> for x in permute('cat'):
213     ...     print(x)
214     cat
215     cta
216     act
217     atc
218     tca
219     tac
220
221     """
222     yield from _permute(seq, "")
223
224
225 def _permute(seq: str, path: str):
226     seq_len = len(seq)
227     if seq_len == 0:
228         yield path
229
230     for i in range(seq_len):
231         car = seq[i]
232         left = seq[0:i]
233         right = seq[i + 1 :]
234         cdr = left + right
235         yield from _permute(cdr, path + car)
236
237
238 def shuffle(seq: MutableSequence[Any]) -> MutableSequence[Any]:
239     """Shuffles a sequence into a random order.
240
241     >>> random.seed(22)
242     >>> shuffle([1, 2, 3, 4, 5])
243     [3, 4, 1, 5, 2]
244
245     >>> shuffle('example')
246     'empaelx'
247
248     """
249     if isinstance(seq, str):
250         import string_utils
251
252         return string_utils.shuffle(seq)
253     else:
254         random.shuffle(seq)
255         return seq
256
257
258 def scramble(seq: MutableSequence[Any]) -> MutableSequence[Any]:
259     return shuffle(seq)
260
261
262 def binary_search(lst: Sequence[Any], target: Any) -> Tuple[bool, int]:
263     """Performs a binary search on lst (which must already be sorted).
264     Returns a Tuple composed of a bool which indicates whether the
265     target was found and an int which indicates the index closest to
266     target whether it was found or not.
267
268     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
269     >>> binary_search(a, 4)
270     (True, 1)
271
272     >>> binary_search(a, 12)
273     (False, 8)
274
275     >>> binary_search(a, 3)
276     (False, 1)
277
278     >>> binary_search(a, 2)
279     (False, 1)
280
281     >>> a.append(9)
282     >>> binary_search(a, 4, sanity_check=True)
283     Traceback (most recent call last):
284     ...
285     AssertionError
286
287     """
288     if __debug__:
289         last = None
290         for x in lst:
291             if last is not None:
292                 assert x >= last  # This asserts iff the list isn't sorted
293             last = x  # in ascending order.
294     return _binary_search(lst, target, 0, len(lst) - 1)
295
296
297 def _binary_search(lst: Sequence[Any], target: Any, low: int, high: int) -> Tuple[bool, int]:
298     if high >= low:
299         mid = (high + low) // 2
300         if lst[mid] == target:
301             return (True, mid)
302         elif lst[mid] > target:
303             return _binary_search(lst, target, low, mid - 1)
304         else:
305             return _binary_search(lst, target, mid + 1, high)
306     else:
307         return (False, low)
308
309
310 def powerset(lst: Sequence[Any]) -> Iterator[Sequence[Any]]:
311     """Returns the powerset of the items in the input sequence.
312
313     >>> for x in powerset([1, 2, 3]):
314     ...     print(x)
315     ()
316     (1,)
317     (2,)
318     (3,)
319     (1, 2)
320     (1, 3)
321     (2, 3)
322     (1, 2, 3)
323     """
324     return chain.from_iterable(combinations(lst, r) for r in range(len(lst) + 1))
325
326
327 if __name__ == '__main__':
328     import doctest
329
330     doctest.testmod()