Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / string_utils.py
index 6f3cc90ed46f5c238b0887848c1cf7504ec3bcc0..08995765411a22bf5272bdae4bdee41b20312bc0 100644 (file)
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 
 """The MIT License (MIT)
 
@@ -39,7 +40,17 @@ import string
 import unicodedata
 import warnings
 from itertools import zip_longest
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)
 from uuid import uuid4
 
 import list_utils
@@ -141,16 +152,16 @@ MARGIN_RE = re.compile(r"^[^\S\r\n]+")
 ESCAPE_SEQUENCE_RE = re.compile(r"\e\[[^A-Za-z]*[A-Za-z]")
 
 NUM_SUFFIXES = {
-    "Pb": (1024 ** 5),
-    "P": (1024 ** 5),
-    "Tb": (1024 ** 4),
-    "T": (1024 ** 4),
-    "Gb": (1024 ** 3),
-    "G": (1024 ** 3),
-    "Mb": (1024 ** 2),
-    "M": (1024 ** 2),
-    "Kb": (1024 ** 1),
-    "K": (1024 ** 1),
+    "Pb": (1024**5),
+    "P": (1024**5),
+    "Tb": (1024**4),
+    "T": (1024**4),
+    "Gb": (1024**3),
+    "G": (1024**3),
+    "Mb": (1024**2),
+    "M": (1024**2),
+    "Kb": (1024**1),
+    "K": (1024**1),
 }
 
 
@@ -960,6 +971,10 @@ def shuffle(in_str: str) -> str:
     return from_char_list(chars)
 
 
+def scramble(in_str: str) -> str:
+    return shuffle(in_str)
+
+
 def strip_html(in_str: str, keep_tag_content: bool = False) -> str:
     """
     Remove html code contained into the given string.
@@ -1074,13 +1089,13 @@ def to_date(in_str: str) -> Optional[datetime.date]:
     """
     Parses a date string.  See DateParser docs for details.
     """
-    import dateparse.dateparse_utils as dp  # type: ignore
+    import dateparse.dateparse_utils as du
 
     try:
-        d = dp.DateParser()
+        d = du.DateParser()  # type: ignore
         d.parse(in_str)
         return d.get_date()
-    except dp.ParseException:
+    except du.ParseException:  # type: ignore
         msg = f'Unable to parse date {in_str}.'
         logger.warning(msg)
     return None
@@ -1093,10 +1108,10 @@ def valid_date(in_str: str) -> bool:
     import dateparse.dateparse_utils as dp
 
     try:
-        d = dp.DateParser()
+        d = dp.DateParser()  # type: ignore
         _ = d.parse(in_str)
         return True
-    except dp.ParseException:
+    except dp.ParseException:  # type: ignore
         msg = f'Unable to parse date {in_str}.'
         logger.warning(msg)
     return False
@@ -1109,9 +1124,9 @@ def to_datetime(in_str: str) -> Optional[datetime.datetime]:
     import dateparse.dateparse_utils as dp
 
     try:
-        d = dp.DateParser()
+        d = dp.DateParser()  # type: ignore
         dt = d.parse(in_str)
-        if type(dt) == datetime.datetime:
+        if isinstance(dt, datetime.datetime):
             return dt
     except ValueError:
         msg = f'Unable to parse datetime {in_str}.'
@@ -1207,7 +1222,23 @@ def sprintf(*args, **kwargs) -> str:
     return ret
 
 
-class SprintfStdout(object):
+def strip_ansi_sequences(in_str: str) -> str:
+    """Strips ANSI sequences out of strings.
+
+    >>> import ansi as a
+    >>> s = a.fg('blue') + 'blue!' + a.reset()
+    >>> len(s)   # '\x1b[38;5;21mblue!\x1b[m'
+    18
+    >>> len(strip_ansi_sequences(s))
+    5
+    >>> strip_ansi_sequences(s)
+    'blue!'
+
+    """
+    return re.sub(r'\x1b\[[\d+;]*[a-z]', '', in_str)
+
+
+class SprintfStdout(contextlib.AbstractContextManager):
     """
     A context manager that captures outputs to stdout.
 
@@ -1227,10 +1258,10 @@ class SprintfStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination.getvalue()
 
-    def __exit__(self, *args) -> None:
+    def __exit__(self, *args) -> Literal[False]:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None  # don't suppress exceptions
+        return False
 
 
 def capitalize_first_letter(txt: str) -> str:
@@ -1370,7 +1401,7 @@ def make_contractions(txt: str) -> str:
             for second in second_list:
                 # Disallow there're/where're.  They're valid English
                 # but sound weird.
-                if (first == 'there' or first == 'where') and second == 'a(re)':
+                if (first in ('there', 'where')) and second == 'a(re)':
                     continue
 
                 pattern = fr'\b({first})\s+{second}\b'
@@ -1456,11 +1487,11 @@ def shuffle_columns_into_list(
     # Column specs map input lines' columns into outputs.
     # [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out.append(chunk)
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out.append(hunk)
     return out
 
 
@@ -1486,11 +1517,11 @@ def shuffle_columns_into_dict(
     # Column specs map input lines' columns into outputs.
     # "key", [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec[1]:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out[spec[0]] = chunk
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out[spec[0]] = hunk
     return out
 
 
@@ -1515,9 +1546,9 @@ def to_ascii(x: str):
     b'1, 2, 3'
 
     """
-    if type(x) is str:
+    if isinstance(x, str):
         return x.encode('ascii')
-    if type(x) is bytes:
+    if isinstance(x, bytes):
         return x
     raise Exception('to_ascii works with strings and bytes')
 
@@ -1638,7 +1669,7 @@ def ip_v4_sort_key(txt: str) -> Optional[Tuple[int, ...]]:
     if not is_ip_v4(txt):
         print(f"not IP: {txt}")
         return None
-    return tuple([int(x) for x in txt.split('.')])
+    return tuple(int(x) for x in txt.split('.'))
 
 
 def path_ancestors_before_descendants_sort_key(volume: str) -> Tuple[str, ...]:
@@ -1653,7 +1684,7 @@ def path_ancestors_before_descendants_sort_key(volume: str) -> Tuple[str, ...]:
     ['/usr', '/usr/local', '/usr/local/bin']
 
     """
-    return tuple([x for x in volume.split('/') if len(x) > 0])
+    return tuple(x for x in volume.split('/') if len(x) > 0)
 
 
 def replace_all(in_str: str, replace_set: str, replacement: str) -> str:
@@ -1669,6 +1700,20 @@ def replace_all(in_str: str, replace_set: str, replacement: str) -> str:
     return in_str
 
 
+def replace_nth(in_str: str, source: str, target: str, nth: int):
+    """Replaces the nth occurrance of a substring within a string.
+
+    >>> replace_nth('this is a test', ' ', '-', 3)
+    'this is a-test'
+
+    """
+    where = [m.start() for m in re.finditer(source, in_str)][nth - 1]
+    before = in_str[:where]
+    after = in_str[where:]
+    after = after.replace(source, target, 1)
+    return before + after
+
+
 if __name__ == '__main__':
     import doctest