More messing with the project file.
[pyutils.git] / src / pyutils / list_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This module contains helper functions for dealing with Python lists."""
6
7 import random
8 from collections import Counter
9 from itertools import chain, combinations, islice
10 from typing import Any, Iterator, List, MutableSequence, Sequence, Tuple
11
12
13 def shard(lst: List[Any], size: int) -> Iterator[Any]:
14     """
15     Shards (i.e. splits) a list into sublists of size `size` whcih,
16     together, contain all items in the original unsharded list.
17
18     Args:
19         lst: the original input list to shard
20         size: the ideal shard size (number of elements per shard)
21
22     Returns:
23         A generator that yields successive shards.
24
25     .. note::
26
27         If `len(lst)` is not an even multiple of `size` then the last
28         shard will not have `size` items in it.  It will have
29         `len(lst) % size` items instead.
30
31     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
32     ...     [_ for _ in sublist]
33     [1, 2, 3]
34     [4, 5, 6]
35     [7, 8, 9]
36     [10, 11, 12]
37     """
38     for x in range(0, len(lst), size):
39         yield islice(lst, x, x + size)
40
41
42 def flatten(lst: List[Any]) -> List[Any]:
43     """
44     Flatten out a list.  That is, for each item in list that contains
45     a list, remove the nested list and replace it with its items.
46
47     Args:
48         lst: the list to flatten
49
50     Returns:
51         The flattened list.  See example.
52
53     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
54     [1, 2, 3, 4, 5, 6, 7, 8, 9]
55     """
56     if len(lst) == 0:
57         return lst
58     if isinstance(lst[0], list):
59         return flatten(lst[0]) + flatten(lst[1:])
60     return lst[:1] + flatten(lst[1:])
61
62
63 def prepend(item: Any, lst: List[Any]) -> List[Any]:
64     """
65     Prepend an item to a list.  An alias for `list.insert(0, item)`.
66     The opposite of `list.append()`.
67
68     Args:
69         item: the item to be prepended
70         lst: the list on which to prepend
71
72     Returns:
73         The list with item prepended.
74
75     >>> prepend('foo', ['bar', 'baz'])
76     ['foo', 'bar', 'baz']
77     """
78     lst.insert(0, item)
79     return lst
80
81
82 def remove_list_if_one_element(lst: List[Any]) -> Any:
83     """
84     Remove the list and return the 0th element iff its length is one.
85
86     Args:
87         lst: the List to check
88
89     Returns:
90         Either `lst` (if `len(lst) > 1`) or `lst[0]` (if `len(lst) == 1`).
91
92     >>> remove_list_if_one_element([1234])
93     1234
94
95     >>> remove_list_if_one_element([1, 2, 3, 4])
96     [1, 2, 3, 4]
97     """
98     if len(lst) == 1:
99         return lst[0]
100     else:
101         return lst
102
103
104 def population_counts(lst: Sequence[Any]) -> Counter:
105     """
106     Return a population count mapping for the list (i.e. the keys are
107     list items and the values are the number of occurrances of that
108     list item in the original list).  Note: this is used internally
109     to implement :meth:`most_common` and :meth:`least_common`.
110
111     Args:
112         lst: the list whose population should be counted
113
114     Returns:
115         a `Counter` containing the population count of `lst` items.
116
117     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
118     Counter({1: 3, 3: 3, 2: 2, 4: 1})
119     """
120     return Counter(lst)
121
122
123 def most_common(lst: List[Any], *, count=1) -> Any:
124     """
125     Return the N most common item in the list.
126
127     Args:
128         lst: the list to find the most common item in
129         count: the number of most common items to return
130
131     Returns:
132         The most common item in `lst`.
133
134     .. warning::
135
136         In the case of ties for most common item, which most common
137         item is returned is undefined.
138
139     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
140     3
141
142     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
143     [3, 1]
144
145     """
146     p = population_counts(lst)
147     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
148
149
150 def least_common(lst: List[Any], *, count=1) -> Any:
151     """
152     Return the N least common item in the list.
153
154     Args:
155         lst: the list to find the least common item in
156         count: the number of least common items to return
157
158     Returns:
159         The least common item in `lst`
160
161     .. warning::
162
163        In the case of ties, which least common item is returned
164        is undefined.
165
166     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
167     4
168
169     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
170     [4, 2]
171     """
172     p = population_counts(lst)
173     mc = p.most_common()[-count:]
174     mc.reverse()
175     return remove_list_if_one_element([_[0] for _ in mc])
176
177
178 def dedup_list(lst: List[Any]) -> List[Any]:
179     """
180     Remove duplicates from the list.
181
182     Args:
183         lst: the list to de-duplicate
184
185     Returns:
186         The de-duplicated input list.  That is, the same list with
187         all extra duplicate items removed.  The list composed of
188         the set of unique items from the input `lst`
189
190     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
191     [1, 2, 3, 4, 5]
192     """
193     return list(set(lst))
194
195
196 def uniq(lst: List[Any]) -> List[Any]:
197     """
198     Alias for :meth:`dedup_list`.
199     """
200     return dedup_list(lst)
201
202
203 def contains_duplicates(lst: List[Any]) -> bool:
204     """
205     Does the list contain duplicate elements or not?
206
207     Args:
208         lst: the list to check for duplicates
209
210     Returns:
211         True if the input `lst` contains duplicated items and
212         False otherwise.
213
214     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
215     >>> contains_duplicates(lst)
216     True
217
218     >>> contains_duplicates(dedup_list(lst))
219     False
220     """
221     seen = set()
222     for _ in lst:
223         if _ in seen:
224             return True
225         seen.add(_)
226     return False
227
228
229 def all_unique(lst: List[Any]) -> bool:
230     """
231     Inverted alias for :meth:`contains_duplicates`.
232     """
233     return not contains_duplicates(lst)
234
235
236 def transpose(lst: List[Any]) -> List[Any]:
237     """
238     Transpose a list of lists.
239
240     Args:
241         lst: the list of lists to be transposed.
242
243     Returns:
244         The transposed result.  See example.
245
246     >>> lst = [[1, 2], [3, 4], [5, 6]]
247     >>> transpose(lst)
248     [[1, 3, 5], [2, 4, 6]]
249
250     """
251     transposed = zip(*lst)
252     return [list(_) for _ in transposed]
253
254
255 def ngrams(lst: Sequence[Any], n: int):
256     """
257     Return the ngrams in the sequence.
258
259     Args:
260         lst: the list in which to find ngrams
261         n: the size of each ngram to return
262
263     Returns:
264         A generator that yields all ngrams of size `n` in `lst`.
265
266     >>> seq = 'encyclopedia'
267     >>> for _ in ngrams(seq, 3):
268     ...     _
269     'enc'
270     'ncy'
271     'cyc'
272     'ycl'
273     'clo'
274     'lop'
275     'ope'
276     'ped'
277     'edi'
278     'dia'
279
280     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
281     >>> for _ in ngrams(seq, 3):
282     ...     _
283     ['this', 'is', 'an']
284     ['is', 'an', 'awesome']
285     ['an', 'awesome', 'test']
286     """
287     for i in range(len(lst) - n + 1):
288         yield lst[i : i + n]
289
290
291 def permute(seq: str):
292     """
293     Returns all permutations of a sequence.
294
295     Args:
296         seq: the sequence to permute
297
298     Returns:
299         All permutations creatable by shuffling items in `seq`.
300
301     .. warning::
302
303         Takes O(N!) time, beware of large inputs.
304
305     >>> for x in permute('cat'):
306     ...     print(x)
307     cat
308     cta
309     act
310     atc
311     tca
312     tac
313     """
314     yield from _permute(seq, "")
315
316
317 def _permute(seq: str, path: str):
318     """Internal helper to permute items recursively."""
319     seq_len = len(seq)
320     if seq_len == 0:
321         yield path
322
323     for i in range(seq_len):
324         car = seq[i]
325         left = seq[0:i]
326         right = seq[i + 1 :]
327         cdr = left + right
328         yield from _permute(cdr, path + car)
329
330
331 def shuffle(seq: MutableSequence[Any]) -> MutableSequence[Any]:
332     """Shuffles a sequence into a random order.
333
334     Args:
335         seq: a sequence to shuffle
336
337     Returns:
338         The shuffled sequence.
339
340     >>> random.seed(22)
341     >>> shuffle([1, 2, 3, 4, 5])
342     [3, 4, 1, 5, 2]
343
344     >>> shuffle('example')
345     'empaelx'
346     """
347     if isinstance(seq, str):
348         from pyutils import string_utils
349
350         return string_utils.shuffle(seq)
351     else:
352         random.shuffle(seq)
353         return seq
354
355
356 def scramble(seq: MutableSequence[Any]) -> MutableSequence[Any]:
357     """An alias for :meth:`shuffle`."""
358     return shuffle(seq)
359
360
361 def binary_search(lst: Sequence[Any], target: Any) -> Tuple[bool, int]:
362     """Performs a binary search on lst (which must already be sorted).
363
364     Args:
365         lst: the (already sorted!) list in which to search
366         target: the item value to be found
367
368     Returns:
369         A Tuple composed of a bool which indicates whether the
370         target was found and an int which indicates the index closest to
371         target whether it was found or not.
372
373     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
374     >>> binary_search(a, 4)
375     (True, 1)
376
377     >>> binary_search(a, 12)
378     (False, 8)
379
380     >>> binary_search(a, 3)
381     (False, 1)
382
383     >>> binary_search(a, 2)
384     (False, 1)
385
386     >>> a.append(9)
387     >>> binary_search(a, 4)
388     Traceback (most recent call last):
389     ...
390     AssertionError
391
392     """
393     if __debug__:
394         last = None
395         for x in lst:
396             if last is not None:
397                 assert x >= last  # This asserts iff the list isn't sorted
398             last = x  # in ascending order.
399     return _binary_search(lst, target, 0, len(lst) - 1)
400
401
402 def _binary_search(
403     lst: Sequence[Any], target: Any, low: int, high: int
404 ) -> Tuple[bool, int]:
405     """Internal helper to perform a binary search recursively."""
406     if high >= low:
407         mid = (high + low) // 2
408         if lst[mid] == target:
409             return (True, mid)
410         elif lst[mid] > target:
411             return _binary_search(lst, target, low, mid - 1)
412         else:
413             return _binary_search(lst, target, mid + 1, high)
414     else:
415         return (False, low)
416
417
418 def powerset(seq: Sequence[Any]) -> Iterator[Sequence[Any]]:
419     """Returns the powerset of the items in the input sequence.  That is,
420     return the set containing every set constructable using items from
421     seq (including the empty set and the "full" set: `seq` itself).
422
423     Args:
424         seq: the sequence whose items will be used to construct the powerset.
425
426     Returns:
427         The powerset composed of all sets possible to create with items from `seq`.
428         See: https://en.wikipedia.org/wiki/Power_set.
429
430     >>> for x in powerset([1, 2, 3]):
431     ...     print(x)
432     ()
433     (1,)
434     (2,)
435     (3,)
436     (1, 2)
437     (1, 3)
438     (2, 3)
439     (1, 2, 3)
440     """
441     return chain.from_iterable(combinations(seq, r) for r in range(len(seq) + 1))
442
443
444 if __name__ == '__main__':
445     import doctest
446
447     doctest.testmod()