More cleanup, yey!
[python_utils.git] / list_utils.py
index 182e2bc5c104908f39a15e4675021e6ed8a7c338..91af8f9eb924fb7e7e04932d58a1bcb6eded0690 100644 (file)
@@ -1,8 +1,10 @@
 #!/usr/bin/env python3
 
 #!/usr/bin/env python3
 
+"""Some useful(?) utilities for dealing with Lists."""
+
 from collections import Counter
 from itertools import islice
 from collections import Counter
 from itertools import islice
-from typing import Any, Iterator, List, Mapping, Sequence
+from typing import Any, Iterator, List, Sequence, Tuple
 
 
 def shard(lst: List[Any], size: int) -> Iterator[Any]:
 
 
 def shard(lst: List[Any], size: int) -> Iterator[Any]:
@@ -48,7 +50,24 @@ def prepend(item: Any, lst: List[Any]) -> List[Any]:
     return lst
 
 
     return lst
 
 
-def population_counts(lst: List[Any]) -> Mapping[Any, int]:
+def remove_list_if_one_element(lst: List[Any]) -> Any:
+    """
+    Remove the list and return the 0th element iff its length is one.
+
+    >>> remove_list_if_one_element([1234])
+    1234
+
+    >>> remove_list_if_one_element([1, 2, 3, 4])
+    [1, 2, 3, 4]
+
+    """
+    if len(lst) == 1:
+        return lst[0]
+    else:
+        return lst
+
+
+def population_counts(lst: Sequence[Any]) -> Counter:
     """
     Return a population count mapping for the list (i.e. the keys are
     list items and the values are the number of occurrances of that
     """
     Return a population count mapping for the list (i.e. the keys are
     list items and the values are the number of occurrances of that
@@ -61,29 +80,39 @@ def population_counts(lst: List[Any]) -> Mapping[Any, int]:
     return Counter(lst)
 
 
     return Counter(lst)
 
 
-def most_common_item(lst: List[Any]) -> Any:
+def most_common(lst: List[Any], *, count=1) -> Any:
 
     """
     Return the most common item in the list.  In the case of ties,
     which most common item is returned will be random.
 
 
     """
     Return the most common item in the list.  In the case of ties,
     which most common item is returned will be random.
 
-    >>> most_common_item([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
+    >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4])
     3
 
     3
 
+    >>> most_common([1, 1, 1, 2, 2, 3, 3, 3, 3, 4, 4], count=2)
+    [3, 1]
+
     """
     """
-    return population_counts(lst).most_common(1)[0][0]
+    p = population_counts(lst)
+    return remove_list_if_one_element([_[0] for _ in p.most_common()[0:count]])
 
 
 
 
-def least_common_item(lst: List[Any]) -> Any:
+def least_common(lst: List[Any], *, count=1) -> Any:
     """
     Return the least common item in the list.  In the case of
     ties, which least common item is returned will be random.
 
     """
     Return the least common item in the list.  In the case of
     ties, which least common item is returned will be random.
 
-    >>> least_common_item([1, 1, 1, 2, 2, 3, 3, 3, 4])
+    >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4])
     4
 
     4
 
+    >>> least_common([1, 1, 1, 2, 2, 3, 3, 3, 4], count=2)
+    [4, 2]
+
     """
     """
-    return population_counts(lst).most_common()[-1][0]
+    p = population_counts(lst)
+    mc = p.most_common()[-count:]
+    mc.reverse()
+    return remove_list_if_one_element([_[0] for _ in mc])
 
 
 def dedup_list(lst: List[Any]) -> List[Any]:
 
 
 def dedup_list(lst: List[Any]) -> List[Any]:
@@ -100,16 +129,158 @@ def dedup_list(lst: List[Any]) -> List[Any]:
 def uniq(lst: List[Any]) -> List[Any]:
     """
     Alias for dedup_list.
 def uniq(lst: List[Any]) -> List[Any]:
     """
     Alias for dedup_list.
-
     """
     return dedup_list(lst)
 
 
     """
     return dedup_list(lst)
 
 
+def contains_duplicates(lst: List[Any]) -> bool:
+    """
+    Does the list contian duplicate elements or not?
+
+    >>> lst = [1, 2, 1, 3, 3, 4, 4, 5, 6, 1, 3, 4]
+    >>> contains_duplicates(lst)
+    True
+
+    >>> contains_duplicates(dedup_list(lst))
+    False
+
+    """
+    seen = set()
+    for _ in lst:
+        if _ in seen:
+            return True
+        seen.add(_)
+    return False
+
+
+def all_unique(lst: List[Any]) -> bool:
+    """
+    Inverted alias for contains_duplicates.
+    """
+    return not contains_duplicates(lst)
+
+
+def transpose(lst: List[Any]) -> List[Any]:
+    """
+    Transpose a list of lists.
+
+    >>> lst = [[1, 2], [3, 4], [5, 6]]
+    >>> transpose(lst)
+    [[1, 3, 5], [2, 4, 6]]
+
+    """
+    transposed = zip(*lst)
+    return [list(_) for _ in transposed]
+
+
 def ngrams(lst: Sequence[Any], n):
 def ngrams(lst: Sequence[Any], n):
+    """
+    Return the ngrams in the sequence.
+
+    >>> seq = 'encyclopedia'
+    >>> for _ in ngrams(seq, 3):
+    ...     _
+    'enc'
+    'ncy'
+    'cyc'
+    'ycl'
+    'clo'
+    'lop'
+    'ope'
+    'ped'
+    'edi'
+    'dia'
+
+    >>> seq = ['this', 'is', 'an', 'awesome', 'test']
+    >>> for _ in ngrams(seq, 3):
+    ...     _
+    ['this', 'is', 'an']
+    ['is', 'an', 'awesome']
+    ['an', 'awesome', 'test']
+    """
     for i in range(len(lst) - n + 1):
     for i in range(len(lst) - n + 1):
-        yield lst[i:i + n]
+        yield lst[i : i + n]
+
+
+def permute(seq: str):
+    """
+    Returns all permutations of a sequence; takes O(N!) time.
+
+    >>> for x in permute('cat'):
+    ...     print(x)
+    cat
+    cta
+    act
+    atc
+    tca
+    tac
+
+    """
+    yield from _permute(seq, "")
+
+
+def _permute(seq: str, path: str):
+    seq_len = len(seq)
+    if seq_len == 0:
+        yield path
+
+    for i in range(seq_len):
+        car = seq[i]
+        left = seq[0:i]
+        right = seq[i + 1 :]
+        cdr = left + right
+        yield from _permute(cdr, path + car)
+
+
+def binary_search(lst: Sequence[Any], target: Any, *, sanity_check=False) -> Tuple[bool, int]:
+    """Performs a binary search on lst (which must already be sorted).
+    Returns a Tuple composed of a bool which indicates whether the
+    target was found and an int which indicates the index closest to
+    target whether it was found or not.
+
+    >>> a = [1, 4, 5, 6, 7, 9, 10, 11]
+    >>> binary_search(a, 4)
+    (True, 1)
+
+    >>> binary_search(a, 12)
+    (False, 8)
+
+    >>> binary_search(a, 3)
+    (False, 1)
+
+    >>> binary_search(a, 2)
+    (False, 1)
+
+    >>> a.append(9)
+    >>> binary_search(a, 4, sanity_check=True)
+    Traceback (most recent call last):
+    ...
+    AssertionError
+
+    """
+    if sanity_check:
+        last = None
+        for x in lst:
+            if last is not None:
+                assert x >= last  # This asserts iff the list isn't sorted
+            last = x  # in ascending order.
+    return _binary_search(lst, target, 0, len(lst) - 1)
+
+
+def _binary_search(lst: Sequence[Any], target: Any, low: int, high: int) -> Tuple[bool, int]:
+    if high >= low:
+        mid = (high + low) // 2
+        if lst[mid] == target:
+            return (True, mid)
+        elif lst[mid] > target:
+            return _binary_search(lst, target, low, mid - 1)
+        else:
+            return _binary_search(lst, target, mid + 1, high)
+    else:
+        return (False, low)
 
 
 if __name__ == '__main__':
     import doctest
 
 
 if __name__ == '__main__':
     import doctest
+
     doctest.testmod()
     doctest.testmod()