More cleanup.
[python_utils.git] / string_utils.py
index 244c450b5ac4b1ec89d8388de6d32eb592b40225..4bec031738e989d10507992387e18aa47996da8e 100644 (file)
@@ -1,29 +1,57 @@
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""The MIT License (MIT)
+
+Copyright (c) 2016-2020 Davide Zanotti
+Modifications Copyright (c) 2021-2022 Scott Gasch
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+This class is based on: https://github.com/daveoncode/python-string-utils.
+"""
 
 import base64
-import contextlib
+import contextlib  # type: ignore
 import datetime
 import io
-from itertools import zip_longest
 import json
 import logging
 import numbers
 import random
 import re
 import string
+import unicodedata
+import warnings
+from itertools import zip_longest
 from typing import (
     Any,
     Callable,
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Sequence,
     Tuple,
 )
-import unicodedata
 from uuid import uuid4
-import warnings
 
 import list_utils
 
@@ -61,19 +89,13 @@ EMAIL_RE = re.compile(r"^{}$".format(EMAILS_RAW_STRING))
 
 EMAILS_RE = re.compile(r"({})".format(EMAILS_RAW_STRING))
 
-CAMEL_CASE_TEST_RE = re.compile(
-    r"^[a-zA-Z]*([a-z]+[A-Z]+|[A-Z]+[a-z]+)[a-zA-Z\d]*$"
-)
+CAMEL_CASE_TEST_RE = re.compile(r"^[a-zA-Z]*([a-z]+[A-Z]+|[A-Z]+[a-z]+)[a-zA-Z\d]*$")
 
 CAMEL_CASE_REPLACE_RE = re.compile(r"([a-z]|[A-Z]+)(?=[A-Z])")
 
-SNAKE_CASE_TEST_RE = re.compile(
-    r"^([a-z]+\d*_[a-z\d_]*|_+[a-z\d]+[a-z\d_]*)$", re.IGNORECASE
-)
+SNAKE_CASE_TEST_RE = re.compile(r"^([a-z]+\d*_[a-z\d_]*|_+[a-z\d]+[a-z\d_]*)$", re.IGNORECASE)
 
-SNAKE_CASE_TEST_DASH_RE = re.compile(
-    r"([a-z]+\d*-[a-z\d-]*|-+[a-z\d]+[a-z\d-]*)$", re.IGNORECASE
-)
+SNAKE_CASE_TEST_DASH_RE = re.compile(r"([a-z]+\d*-[a-z\d-]*|-+[a-z\d]+[a-z\d-]*)$", re.IGNORECASE)
 
 SNAKE_CASE_REPLACE_RE = re.compile(r"(_)([a-z\d])")
 
@@ -88,13 +110,9 @@ CREDIT_CARDS = {
     "JCB": re.compile(r"^(?:2131|1800|35\d{3})\d{11}$"),
 }
 
-JSON_WRAPPER_RE = re.compile(
-    r"^\s*[\[{]\s*(.*)\s*[\}\]]\s*$", re.MULTILINE | re.DOTALL
-)
+JSON_WRAPPER_RE = re.compile(r"^\s*[\[{]\s*(.*)\s*[\}\]]\s*$", re.MULTILINE | re.DOTALL)
 
-UUID_RE = re.compile(
-    r"^[a-f\d]{8}-[a-f\d]{4}-[a-f\d]{4}-[a-f\d]{4}-[a-f\d]{12}$", re.IGNORECASE
-)
+UUID_RE = re.compile(r"^[a-f\d]{8}-[a-f\d]{4}-[a-f\d]{4}-[a-f\d]{4}-[a-f\d]{12}$", re.IGNORECASE)
 
 UUID_HEX_OK_RE = re.compile(
     r"^[a-f\d]{8}-?[a-f\d]{4}-?[a-f\d]{4}-?[a-f\d]{4}-?[a-f\d]{12}$",
@@ -109,17 +127,11 @@ IP_V6_RE = re.compile(r"^([a-z\d]{0,4}:){7}[a-z\d]{0,4}$", re.IGNORECASE)
 
 ANYWHERE_IP_V6_RE = re.compile(r"([a-z\d]{0,4}:){7}[a-z\d]{0,4}", re.IGNORECASE)
 
-MAC_ADDRESS_RE = re.compile(
-    r"^([0-9A-F]{2}[:-]){5}([0-9A-F]{2})$", re.IGNORECASE
-)
+MAC_ADDRESS_RE = re.compile(r"^([0-9A-F]{2}[:-]){5}([0-9A-F]{2})$", re.IGNORECASE)
 
-ANYWHERE_MAC_ADDRESS_RE = re.compile(
-    r"([0-9A-F]{2}[:-]){5}([0-9A-F]{2})", re.IGNORECASE
-)
+ANYWHERE_MAC_ADDRESS_RE = re.compile(r"([0-9A-F]{2}[:-]){5}([0-9A-F]{2})", re.IGNORECASE)
 
-WORDS_COUNT_RE = re.compile(
-    r"\W*[^\W_]+\W*", re.IGNORECASE | re.MULTILINE | re.UNICODE
-)
+WORDS_COUNT_RE = re.compile(r"\W*[^\W_]+\W*", re.IGNORECASE | re.MULTILINE | re.UNICODE)
 
 HTML_RE = re.compile(
     r"((<([a-z]+:)?[a-z]+[^>]*/?>)(.*?(</([a-z]+:)?[a-z]+>))?|<!--.*-->|<!doctype.*>)",
@@ -133,25 +145,23 @@ HTML_TAG_ONLY_RE = re.compile(
 
 SPACES_RE = re.compile(r"\s")
 
-NO_LETTERS_OR_NUMBERS_RE = re.compile(
-    r"[^\w\d]+|_+", re.IGNORECASE | re.UNICODE
-)
+NO_LETTERS_OR_NUMBERS_RE = re.compile(r"[^\w\d]+|_+", re.IGNORECASE | re.UNICODE)
 
 MARGIN_RE = re.compile(r"^[^\S\r\n]+")
 
 ESCAPE_SEQUENCE_RE = re.compile(r"\e\[[^A-Za-z]*[A-Za-z]")
 
 NUM_SUFFIXES = {
-    "Pb": (1024 ** 5),
-    "P": (1024 ** 5),
-    "Tb": (1024 ** 4),
-    "T": (1024 ** 4),
-    "Gb": (1024 ** 3),
-    "G": (1024 ** 3),
-    "Mb": (1024 ** 2),
-    "M": (1024 ** 2),
-    "Kb": (1024 ** 1),
-    "K": (1024 ** 1),
+    "Pb": (1024**5),
+    "P": (1024**5),
+    "Tb": (1024**4),
+    "T": (1024**4),
+    "Gb": (1024**3),
+    "G": (1024**3),
+    "Mb": (1024**2),
+    "M": (1024**2),
+    "Kb": (1024**1),
+    "K": (1024**1),
 }
 
 
@@ -390,9 +400,7 @@ def strip_escape_sequences(in_str: str) -> str:
     return in_str
 
 
-def add_thousands_separator(
-    in_str: str, *, separator_char=',', places=3
-) -> str:
+def add_thousands_separator(in_str: str, *, separator_char=',', places=3) -> str:
     """
     Add thousands separator to a numeric string.  Also handles numbers.
 
@@ -411,22 +419,16 @@ def add_thousands_separator(
     if isinstance(in_str, numbers.Number):
         in_str = f'{in_str}'
     if is_number(in_str):
-        return _add_thousands_separator(
-            in_str, separator_char=separator_char, places=places
-        )
+        return _add_thousands_separator(in_str, separator_char=separator_char, places=places)
     raise ValueError(in_str)
 
 
-def _add_thousands_separator(
-    in_str: str, *, separator_char=',', places=3
-) -> str:
+def _add_thousands_separator(in_str: str, *, separator_char=',', places=3) -> str:
     decimal_part = ""
     if '.' in in_str:
         (in_str, decimal_part) = in_str.split('.')
     tmp = [iter(in_str[::-1])] * places
-    ret = separator_char.join(
-        "".join(x) for x in zip_longest(*tmp, fillvalue="")
-    )[::-1]
+    ret = separator_char.join("".join(x) for x in zip_longest(*tmp, fillvalue=""))[::-1]
     if len(decimal_part) > 0:
         ret += '.'
         ret += decimal_part
@@ -467,11 +469,7 @@ def is_email(in_str: Any) -> bool:
     >>> is_email('@gmail.com')
     False
     """
-    if (
-        not is_full_string(in_str)
-        or len(in_str) > 320
-        or in_str.startswith(".")
-    ):
+    if not is_full_string(in_str) or len(in_str) > 320 or in_str.startswith("."):
         return False
 
     try:
@@ -481,12 +479,7 @@ def is_email(in_str: Any) -> bool:
 
         # head's size must be <= 64, tail <= 255, head must not start
         # with a dot or contain multiple consecutive dots.
-        if (
-            len(head) > 64
-            or len(tail) > 255
-            or head.endswith(".")
-            or (".." in head)
-        ):
+        if len(head) > 64 or len(tail) > 255 or head.endswith(".") or (".." in head):
             return False
 
         # removes escaped spaces, so that later on the test regex will
@@ -603,9 +596,7 @@ def is_camel_case(in_str: Any) -> bool:
     - it contains both lowercase and uppercase letters
     - it does not start with a number
     """
-    return (
-        is_full_string(in_str) and CAMEL_CASE_TEST_RE.match(in_str) is not None
-    )
+    return is_full_string(in_str) and CAMEL_CASE_TEST_RE.match(in_str) is not None
 
 
 def is_snake_case(in_str: Any, *, separator: str = "_") -> bool:
@@ -630,14 +621,10 @@ def is_snake_case(in_str: Any, *, separator: str = "_") -> bool:
     """
     if is_full_string(in_str):
         re_map = {"_": SNAKE_CASE_TEST_RE, "-": SNAKE_CASE_TEST_DASH_RE}
-        re_template = (
-            r"([a-z]+\d*{sign}[a-z\d{sign}]*|{sign}+[a-z\d]+[a-z\d{sign}]*)"
-        )
+        re_template = r"([a-z]+\d*{sign}[a-z\d{sign}]*|{sign}+[a-z\d]+[a-z\d{sign}]*)"
         r = re_map.get(
             separator,
-            re.compile(
-                re_template.format(sign=re.escape(separator)), re.IGNORECASE
-            ),
+            re.compile(re_template.format(sign=re.escape(separator)), re.IGNORECASE),
         )
         return r.match(in_str) is not None
     return False
@@ -926,9 +913,7 @@ def camel_case_to_snake_case(in_str, *, separator="_"):
         raise ValueError(in_str)
     if not is_camel_case(in_str):
         return in_str
-    return CAMEL_CASE_REPLACE_RE.sub(
-        lambda m: m.group(1) + separator, in_str
-    ).lower()
+    return CAMEL_CASE_REPLACE_RE.sub(lambda m: m.group(1) + separator, in_str).lower()
 
 
 def snake_case_to_camel_case(
@@ -1100,13 +1085,13 @@ def to_date(in_str: str) -> Optional[datetime.date]:
     """
     Parses a date string.  See DateParser docs for details.
     """
-    import dateparse.dateparse_utils as dp
+    import dateparse.dateparse_utils as du
 
     try:
-        d = dp.DateParser()
+        d = du.DateParser()  # type: ignore
         d.parse(in_str)
         return d.get_date()
-    except dp.ParseException:
+    except du.ParseException:  # type: ignore
         msg = f'Unable to parse date {in_str}.'
         logger.warning(msg)
     return None
@@ -1119,10 +1104,10 @@ def valid_date(in_str: str) -> bool:
     import dateparse.dateparse_utils as dp
 
     try:
-        d = dp.DateParser()
+        d = dp.DateParser()  # type: ignore
         _ = d.parse(in_str)
         return True
-    except dp.ParseException:
+    except dp.ParseException:  # type: ignore
         msg = f'Unable to parse date {in_str}.'
         logger.warning(msg)
     return False
@@ -1135,9 +1120,9 @@ def to_datetime(in_str: str) -> Optional[datetime.datetime]:
     import dateparse.dateparse_utils as dp
 
     try:
-        d = dp.DateParser()
+        d = dp.DateParser()  # type: ignore
         dt = d.parse(in_str)
-        if type(dt) == datetime.datetime:
+        if isinstance(dt, datetime.datetime):
             return dt
     except ValueError:
         msg = f'Unable to parse datetime {in_str}.'
@@ -1233,7 +1218,7 @@ def sprintf(*args, **kwargs) -> str:
     return ret
 
 
-class SprintfStdout(object):
+class SprintfStdout(contextlib.AbstractContextManager):
     """
     A context manager that captures outputs to stdout.
 
@@ -1246,17 +1231,43 @@ class SprintfStdout(object):
 
     def __init__(self) -> None:
         self.destination = io.StringIO()
-        self.recorder = None
+        self.recorder: contextlib.redirect_stdout
 
     def __enter__(self) -> Callable[[], str]:
         self.recorder = contextlib.redirect_stdout(self.destination)
         self.recorder.__enter__()
         return lambda: self.destination.getvalue()
 
-    def __exit__(self, *args) -> None:
+    def __exit__(self, *args) -> Literal[False]:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None  # don't suppress exceptions
+        return False
+
+
+def capitalize_first_letter(txt: str) -> str:
+    """Capitalize the first letter of a string.
+
+    >>> capitalize_first_letter('test')
+    'Test'
+    >>> capitalize_first_letter("ALREADY!")
+    'ALREADY!'
+
+    """
+    return txt[0].upper() + txt[1:]
+
+
+def it_they(n: int) -> str:
+    """It or they?
+
+    >>> it_they(1)
+    'it'
+    >>> it_they(100)
+    'they'
+
+    """
+    if n == 1:
+        return "it"
+    return "they"
 
 
 def is_are(n: int) -> str:
@@ -1291,6 +1302,98 @@ def pluralize(n: int) -> str:
     return "s"
 
 
+def make_contractions(txt: str) -> str:
+    """Glue words together to form contractions.
+
+    >>> make_contractions('It is nice today.')
+    "It's nice today."
+
+    >>> make_contractions('I can    not even...')
+    "I can't even..."
+
+    >>> make_contractions('She could not see!')
+    "She couldn't see!"
+
+    >>> make_contractions('But she will not go.')
+    "But she won't go."
+
+    >>> make_contractions('Verily, I shall not.')
+    "Verily, I shan't."
+
+    >>> make_contractions('No you cannot.')
+    "No you can't."
+
+    >>> make_contractions('I said you can not go.')
+    "I said you can't go."
+
+    """
+
+    first_second = [
+        (
+            [
+                'are',
+                'could',
+                'did',
+                'has',
+                'have',
+                'is',
+                'must',
+                'should',
+                'was',
+                'were',
+                'would',
+            ],
+            ['(n)o(t)'],
+        ),
+        (
+            [
+                "I",
+                "you",
+                "he",
+                "she",
+                "it",
+                "we",
+                "they",
+                "how",
+                "why",
+                "when",
+                "where",
+                "who",
+                "there",
+            ],
+            ['woul(d)', 'i(s)', 'a(re)', 'ha(s)', 'ha(ve)', 'ha(d)', 'wi(ll)'],
+        ),
+    ]
+
+    # Special cases: can't, shan't and won't.
+    txt = re.sub(r'\b(can)\s*no(t)\b', r"\1'\2", txt, count=0, flags=re.IGNORECASE)
+    txt = re.sub(r'\b(sha)ll\s*(n)o(t)\b', r"\1\2'\3", txt, count=0, flags=re.IGNORECASE)
+    txt = re.sub(
+        r'\b(w)ill\s*(n)(o)(t)\b',
+        r"\1\3\2'\4",
+        txt,
+        count=0,
+        flags=re.IGNORECASE,
+    )
+
+    for first_list, second_list in first_second:
+        for first in first_list:
+            for second in second_list:
+                # Disallow there're/where're.  They're valid English
+                # but sound weird.
+                if (first in ('there', 'where')) and second == 'a(re)':
+                    continue
+
+                pattern = fr'\b({first})\s+{second}\b'
+                if second == '(n)o(t)':
+                    replacement = r"\1\2'\3"
+                else:
+                    replacement = r"\1'\2"
+                txt = re.sub(pattern, replacement, txt, count=0, flags=re.IGNORECASE)
+
+    return txt
+
+
 def thify(n: int) -> str:
     """Return the proper cardinal suffix for a number.
 
@@ -1343,7 +1446,7 @@ def trigrams(txt: str):
 
 
 def shuffle_columns_into_list(
-    input_lines: Iterable[str], column_specs: Iterable[Iterable[int]], delim=''
+    input_lines: Sequence[str], column_specs: Iterable[Iterable[int]], delim=''
 ) -> Iterable[str]:
     """Helper to shuffle / parse columnar data and return the results as a
     list.  The column_specs argument is an iterable collection of
@@ -1364,16 +1467,16 @@ def shuffle_columns_into_list(
     # Column specs map input lines' columns into outputs.
     # [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out.append(chunk)
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out.append(hunk)
     return out
 
 
 def shuffle_columns_into_dict(
-    input_lines: Iterable[str],
+    input_lines: Sequence[str],
     column_specs: Iterable[Tuple[str, Iterable[int]]],
     delim='',
 ) -> Dict[str, str]:
@@ -1394,11 +1497,11 @@ def shuffle_columns_into_dict(
     # Column specs map input lines' columns into outputs.
     # "key", [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec[1]:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out[spec[0]] = chunk
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out[spec[0]] = hunk
     return out
 
 
@@ -1423,14 +1526,14 @@ def to_ascii(x: str):
     b'1, 2, 3'
 
     """
-    if type(x) is str:
+    if isinstance(x, str):
         return x.encode('ascii')
-    if type(x) is bytes:
+    if isinstance(x, bytes):
         return x
     raise Exception('to_ascii works with strings and bytes')
 
 
-def to_base64(txt: str, *, encoding='utf-8', errors='surrogatepass') -> str:
+def to_base64(txt: str, *, encoding='utf-8', errors='surrogatepass') -> bytes:
     """Encode txt and then encode the bytes with a 64-character
     alphabet.  This is compatible with uudecode.
 
@@ -1463,7 +1566,7 @@ def is_base64(txt: str) -> bool:
     return True
 
 
-def from_base64(b64: str, encoding='utf-8', errors='surrogatepass') -> str:
+def from_base64(b64: bytes, encoding='utf-8', errors='surrogatepass') -> str:
     """Convert base64 encoded string back to normal strings.
 
     >>> from_base64(b'aGVsbG8/\\n')
@@ -1488,9 +1591,7 @@ def chunk(txt: str, chunk_size):
         yield txt[x : x + chunk_size]
 
 
-def to_bitstring(
-    txt: str, *, delimiter='', encoding='utf-8', errors='surrogatepass'
-) -> str:
+def to_bitstring(txt: str, *, delimiter='', encoding='utf-8', errors='surrogatepass') -> str:
     """Encode txt and then chop it into bytes.  Note: only bitstrings
     with delimiter='' are interpretable by from_bitstring.
 
@@ -1531,13 +1632,10 @@ def from_bitstring(bits: str, encoding='utf-8', errors='surrogatepass') -> str:
 
     """
     n = int(bits, 2)
-    return (
-        n.to_bytes((n.bit_length() + 7) // 8, 'big').decode(encoding, errors)
-        or '\0'
-    )
+    return n.to_bytes((n.bit_length() + 7) // 8, 'big').decode(encoding, errors) or '\0'
 
 
-def ip_v4_sort_key(txt: str) -> Tuple[int]:
+def ip_v4_sort_key(txt: str) -> Optional[Tuple[int, ...]]:
     """Turn an IPv4 address into a tuple for sorting purposes.
 
     >>> ip_v4_sort_key('10.0.0.18')
@@ -1554,7 +1652,7 @@ def ip_v4_sort_key(txt: str) -> Tuple[int]:
     return tuple([int(x) for x in txt.split('.')])
 
 
-def path_ancestors_before_descendants_sort_key(volume: str) -> Tuple[str]:
+def path_ancestors_before_descendants_sort_key(volume: str) -> Tuple[str, ...]:
     """Chunk up a file path so that parent/ancestor paths sort before
     children/descendant paths.