Fix typo.
[pyutils.git] / src / pyutils / list_utils.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """This module contains helper functions for dealing with Python lists."""
6
7 import random
8 from collections import Counter
9 from itertools import chain, combinations, islice
10 from typing import (
11     Any,
12     Generator,
13     Iterator,
14     List,
15     MutableSequence,
16     Sequence,
17     Tuple,
18     TypeVar,
19 )
20
21
22 def shard(lst: List[Any], size: int) -> Iterator[Any]:
23     """
24     Shards (i.e. splits) a list into sublists of size `size` whcih,
25     together, contain all items in the original unsharded list.
26
27     Args:
28         lst: the original input list to shard
29         size: the ideal shard size (number of elements per shard)
30
31     Returns:
32         A generator that yields successive shards.
33
34     .. note::
35
36         If `len(lst)` is not an even multiple of `size` then the last
37         shard will not have `size` items in it.  It will have
38         `len(lst) % size` items instead.
39
40     >>> for sublist in shard([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 3):
41     ...     [_ for _ in sublist]
42     [1, 2, 3]
43     [4, 5, 6]
44     [7, 8, 9]
45     [10, 11, 12]
46     """
47     for x in range(0, len(lst), size):
48         yield islice(lst, x, x + size)
49
50
51 def flatten(lst: List[Any]) -> List[Any]:
52     """
53     Flatten out a list.  That is, for each item in list that contains
54     a list, remove the nested list and replace it with its items.
55
56     Args:
57         lst: the list to flatten
58
59     Returns:
60         The flattened list.  See example.
61
62     >>> flatten([ 1, [2, 3, 4, [5], 6], 7, [8, [9]]])
63     [1, 2, 3, 4, 5, 6, 7, 8, 9]
64     """
65     if len(lst) == 0:
66         return lst
67     if isinstance(lst[0], list):
68         return flatten(lst[0]) + flatten(lst[1:])
69     return lst[:1] + flatten(lst[1:])
70
71
72 def prepend(item: Any, lst: List[Any]) -> List[Any]:
73     """
74     Prepend an item to a list.  An alias for `list.insert(0, item)`.
75     The opposite of `list.append()`.
76
77     Args:
78         item: the item to be prepended
79         lst: the list on which to prepend
80
81     Returns:
82         The list with item prepended.
83
84     >>> prepend('foo', ['bar', 'baz'])
85     ['foo', 'bar', 'baz']
86     """
87     lst.insert(0, item)
88     return lst
89
90
91 def remove_list_if_one_element(lst: List[Any]) -> Any:
92     """
93     Remove the list and return the 0th element iff its length is one.
94
95     Args:
96         lst: the List to check
97
98     Returns:
99         Either `lst` (if `len(lst) > 1`) or `lst[0]` (if `len(lst) == 1`).
100
101     >>> remove_list_if_one_element([1234])
102     1234
103
104     >>> remove_list_if_one_element([1, 2, 3, 4])
105     [1, 2, 3, 4]
106     """
107     if len(lst) == 1:
108         return lst[0]
109     else:
110         return lst
111
112
113 def population_counts(lst: Sequence[Any]) -> Counter:
114     """
115     Return a population count mapping for the list (i.e. the keys are
116     list items and the values are the number of occurrances of that
117     list item in the original list).  Note: this is used internally
118     to implement :meth:`most_common` and :meth:`least_common`.
119
120     Args:
121         lst: the list whose population should be counted
122
123     Returns:
124         a `Counter` containing the population count of `lst` items.
125
126     >>> population_counts([1, 1, 1, 2, 2, 3, 3, 3, 4])
127     Counter({1: 3, 3: 3, 2: 2, 4: 1})
128     """
129     return Counter(lst)
130
131
132 def most_common(lst: List[Any], *, count: int = 1) -> Any:
133     """
134     Return the N most common item in the list.
135
136     Args:
137         lst: the list to find the most common item in
138         count: the number of most common items to return
139
140     Returns:
141         The most common item in `lst`.
142
143     .. warning::
144
145         In the case of ties for most common item, which most common
146         item is returned is undefined.
147
148     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
149     3
150
151     >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
152     [3, 1]
153
154     """
155     p = population_counts(lst)
156     return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
157
158
159 def least_common(lst: List[Any], *, count: int = 1) -> Any:
160     """
161     Return the N least common item in the list.
162
163     Args:
164         lst: the list to find the least common item in
165         count: the number of least common items to return
166
167     Returns:
168         The least common item in `lst`
169
170     .. warning::
171
172        In the case of ties, which least common item is returned
173        is undefined.
174
175     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
176     4
177
178     >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
179     [4, 2]
180     """
181     p = population_counts(lst)
182     mc = p.most_common()[-count:]
183     mc.reverse()
184     return remove_list_if_one_element([_[0] for _ in mc])
185
186
187 def dedup_list(lst: List[Any]) -> List[Any]:
188     """
189     Remove duplicates from the list.
190
191     Args:
192         lst: the list to de-duplicate
193
194     Returns:
195         The de-duplicated input list.  That is, the same list with
196         all extra duplicate items removed.  The list composed of
197         the set of unique items from the input `lst`
198
199     >>> dedup_list([1, 2, 1, 3, 3, 4, 2, 3, 4, 5, 1])
200     [1, 2, 3, 4, 5]
201     """
202     return list(set(lst))
203
204
205 def uniq(lst: List[Any]) -> List[Any]:
206     """
207     Alias for :meth:`dedup_list`.
208     """
209     return dedup_list(lst)
210
211
212 def contains_duplicates(lst: List[Any]) -> bool:
213     """
214     Does the list contain duplicate elements or not?
215
216     Args:
217         lst: the list to check for duplicates
218
219     Returns:
220         True if the input `lst` contains duplicated items and
221         False otherwise.
222
223     >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
224     >>> contains_duplicates(lst)
225     True
226
227     >>> contains_duplicates(dedup_list(lst))
228     False
229     """
230     seen = set()
231     for _ in lst:
232         if _ in seen:
233             return True
234         seen.add(_)
235     return False
236
237
238 def all_unique(lst: List[Any]) -> bool:
239     """
240     Inverted alias for :meth:`contains_duplicates`.
241     """
242     return not contains_duplicates(lst)
243
244
245 def transpose(lst: List[Any]) -> List[Any]:
246     """
247     Transpose a list of lists.
248
249     Args:
250         lst: the list of lists to be transposed.
251
252     Returns:
253         The transposed result.  See example.
254
255     >>> lst = [[1, 2], [3, 4], [5, 6]]
256     >>> transpose(lst)
257     [[1, 3, 5], [2, 4, 6]]
258
259     """
260     transposed = zip(*lst)
261     return [list(_) for _ in transposed]
262
263
264 T = TypeVar('T')
265
266
267 def ngrams(lst: Sequence[T], n: int) -> Generator[Sequence[T], T, None]:
268     """
269     Return the ngrams in the sequence.
270
271     Args:
272         lst: the list in which to find ngrams
273         n: the size of each ngram to return
274
275     Returns:
276         A generator that yields all ngrams of size `n` in `lst`.
277
278     >>> seq = 'encyclopedia'
279     >>> for _ in ngrams(seq, 3):
280     ...     _
281     'enc'
282     'ncy'
283     'cyc'
284     'ycl'
285     'clo'
286     'lop'
287     'ope'
288     'ped'
289     'edi'
290     'dia'
291
292     >>> seq = ['this', 'is', 'an', 'awesome', 'test']
293     >>> for _ in ngrams(seq, 3):
294     ...     _
295     ['this', 'is', 'an']
296     ['is', 'an', 'awesome']
297     ['an', 'awesome', 'test']
298     """
299     for i in range(len(lst) - n + 1):
300         yield lst[i : i + n]
301
302
303 def permute(seq: str) -> Generator[str, str, None]:
304     """
305     Returns all permutations of a sequence.
306
307     Args:
308         seq: the sequence to permute
309
310     Returns:
311         All permutations creatable by shuffling items in `seq`.
312
313     .. warning::
314
315         Takes O(N!) time, beware of large inputs.
316
317     >>> for x in permute('cat'):
318     ...     print(x)
319     cat
320     cta
321     act
322     atc
323     tca
324     tac
325     """
326     yield from _permute(seq, "")
327
328
329 def _permute(seq: str, path: str) -> Generator[str, str, None]:
330     """Internal helper to permute items recursively."""
331     seq_len = len(seq)
332     if seq_len == 0:
333         yield path
334
335     for i in range(seq_len):
336         car = seq[i]
337         left = seq[0:i]
338         right = seq[i + 1 :]
339         cdr = left + right
340         yield from _permute(cdr, path + car)
341
342
343 def shuffle(seq: MutableSequence[Any]) -> MutableSequence[Any]:
344     """Shuffles a sequence into a random order.
345
346     Args:
347         seq: a sequence to shuffle
348
349     Returns:
350         The shuffled sequence.
351
352     >>> random.seed(22)
353     >>> shuffle([1, 2, 3, 4, 5])
354     [3, 4, 1, 5, 2]
355
356     >>> shuffle('example')
357     'empaelx'
358     """
359     if isinstance(seq, str):
360         from pyutils import string_utils
361
362         return string_utils.shuffle(seq)
363     else:
364         random.shuffle(seq)
365         return seq
366
367
368 def scramble(seq: MutableSequence[Any]) -> MutableSequence[Any]:
369     """An alias for :meth:`shuffle`."""
370     return shuffle(seq)
371
372
373 def binary_search(lst: Sequence[Any], target: Any) -> Tuple[bool, int]:
374     """Performs a binary search on lst (which must already be sorted).
375
376     Args:
377         lst: the (already sorted!) list in which to search
378         target: the item value to be found
379
380     Returns:
381         A Tuple composed of a bool which indicates whether the
382         target was found and an int which indicates the index closest to
383         target whether it was found or not.
384
385     >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
386     >>> binary_search(a, 4)
387     (True, 1)
388
389     >>> binary_search(a, 12)
390     (False, 8)
391
392     >>> binary_search(a, 3)
393     (False, 1)
394
395     >>> binary_search(a, 2)
396     (False, 1)
397
398     >>> a.append(9)
399     >>> binary_search(a, 4)
400     Traceback (most recent call last):
401     ...
402     AssertionError
403
404     """
405     if __debug__:
406         last = None
407         for x in lst:
408             if last is not None:
409                 assert x >= last  # This asserts iff the list isn't sorted
410             last = x  # in ascending order.
411     return _binary_search(lst, target, 0, len(lst) - 1)
412
413
414 def _binary_search(
415     lst: Sequence[Any], target: Any, low: int, high: int
416 ) -> Tuple[bool, int]:
417     """Internal helper to perform a binary search recursively."""
418     if high >= low:
419         mid = (high + low) // 2
420         if lst[mid] == target:
421             return (True, mid)
422         elif lst[mid] > target:
423             return _binary_search(lst, target, low, mid - 1)
424         else:
425             return _binary_search(lst, target, mid + 1, high)
426     else:
427         return (False, low)
428
429
430 def powerset(seq: Sequence[Any]) -> Iterator[Sequence[Any]]:
431     """Returns the powerset of the items in the input sequence.  That is,
432     return the set containing every set constructable using items from
433     seq (including the empty set and the "full" set: `seq` itself).
434
435     Args:
436         seq: the sequence whose items will be used to construct the powerset.
437
438     Returns:
439         The powerset composed of all sets possible to create with items from `seq`.
440         See: https://en.wikipedia.org/wiki/Power_set.
441
442     >>> for x in powerset([1, 2, 3]):
443     ...     print(x)
444     ()
445     (1,)
446     (2,)
447     (3,)
448     (1, 2)
449     (1, 3)
450     (2, 3)
451     (1, 2, 3)
452     """
453     return chain.from_iterable(combinations(seq, r) for r in range(len(seq) + 1))
454
455
456 if __name__ == "__main__":
457     import doctest
458
459     doctest.testmod()