More cleanup.
[python_utils.git] / string_utils.py
index 3c97ff7991726012ff091dc7aa042abb88000bf6..4bec031738e989d10507992387e18aa47996da8e 100644 (file)
@@ -40,7 +40,17 @@ import string
 import unicodedata
 import warnings
 from itertools import zip_longest
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+)
 from uuid import uuid4
 
 import list_utils
@@ -1091,7 +1101,6 @@ def valid_date(in_str: str) -> bool:
     """
     True if the string represents a valid date.
     """
-    import dateparse
     import dateparse.dateparse_utils as dp
 
     try:
@@ -1113,7 +1122,7 @@ def to_datetime(in_str: str) -> Optional[datetime.datetime]:
     try:
         d = dp.DateParser()  # type: ignore
         dt = d.parse(in_str)
-        if type(dt) == datetime.datetime:
+        if isinstance(dt, datetime.datetime):
             return dt
     except ValueError:
         msg = f'Unable to parse datetime {in_str}.'
@@ -1209,7 +1218,7 @@ def sprintf(*args, **kwargs) -> str:
     return ret
 
 
-class SprintfStdout(object):
+class SprintfStdout(contextlib.AbstractContextManager):
     """
     A context manager that captures outputs to stdout.
 
@@ -1229,10 +1238,10 @@ class SprintfStdout(object):
         self.recorder.__enter__()
         return lambda: self.destination.getvalue()
 
-    def __exit__(self, *args) -> None:
+    def __exit__(self, *args) -> Literal[False]:
         self.recorder.__exit__(*args)
         self.destination.seek(0)
-        return None  # don't suppress exceptions
+        return False
 
 
 def capitalize_first_letter(txt: str) -> str:
@@ -1372,7 +1381,7 @@ def make_contractions(txt: str) -> str:
             for second in second_list:
                 # Disallow there're/where're.  They're valid English
                 # but sound weird.
-                if (first == 'there' or first == 'where') and second == 'a(re)':
+                if (first in ('there', 'where')) and second == 'a(re)':
                     continue
 
                 pattern = fr'\b({first})\s+{second}\b'
@@ -1458,11 +1467,11 @@ def shuffle_columns_into_list(
     # Column specs map input lines' columns into outputs.
     # [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out.append(chunk)
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out.append(hunk)
     return out
 
 
@@ -1488,11 +1497,11 @@ def shuffle_columns_into_dict(
     # Column specs map input lines' columns into outputs.
     # "key", [col1, col2...]
     for spec in column_specs:
-        chunk = ''
+        hunk = ''
         for n in spec[1]:
-            chunk = chunk + delim + input_lines[n]
-        chunk = chunk.strip(delim)
-        out[spec[0]] = chunk
+            hunk = hunk + delim + input_lines[n]
+        hunk = hunk.strip(delim)
+        out[spec[0]] = hunk
     return out
 
 
@@ -1517,9 +1526,9 @@ def to_ascii(x: str):
     b'1, 2, 3'
 
     """
-    if type(x) is str:
+    if isinstance(x, str):
         return x.encode('ascii')
-    if type(x) is bytes:
+    if isinstance(x, bytes):
         return x
     raise Exception('to_ascii works with strings and bytes')