Smart outlets
[python_utils.git] / ansi.py
diff --git a/ansi.py b/ansi.py
index 4c580c0a0262ad133a71f6def6ca63a7d003dd8f..4c09db3df76c1c66561f097e6c48d48caa62156c 100755 (executable)
--- a/ansi.py
+++ b/ansi.py
@@ -1,9 +1,14 @@
 #!/usr/bin/env python3
 
+from abc import abstractmethod
 import difflib
+import io
 import logging
+import re
 import sys
-from typing import Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+
+import logging_utils
 
 logger = logging.getLogger(__name__)
 
@@ -1723,6 +1728,7 @@ def _find_color_by_name(name: str) -> Tuple[int, int, int]:
     return rgb
 
 
+@logging_utils.squelch_repeated_log_messages(1)
 def fg(name: Optional[str] = "",
        red: Optional[int] = None,
        green: Optional[int] = None,
@@ -1732,6 +1738,9 @@ def fg(name: Optional[str] = "",
        force_216color: bool = False) -> str:
     import string_utils
 
+    if name is not None and name == 'reset':
+        return '\033[39m'
+
     if name is not None and string_utils.is_full_string(name):
         rgb = _find_color_by_name(name)
         return fg(
@@ -1816,6 +1825,9 @@ def bg(name: Optional[str] = "",
        force_216color: bool = False) -> str:
     import string_utils
 
+    if name is not None and name == 'reset':
+        return '\033[49m'
+
     if name is not None and string_utils.is_full_string(name):
         rgb = _find_color_by_name(name)
         return bg(
@@ -1846,6 +1858,37 @@ def bg(name: Optional[str] = "",
     return bg_24bit(red, green, blue)
 
 
+class StdoutInterceptor(io.TextIOBase):
+    def __init__(self):
+        self.saved_stdout: Optional[io.TextIOBase] = None
+        self.buf = ''
+
+    @abstractmethod
+    def write(self, s):
+        pass
+
+    def __enter__(self) -> None:
+        self.saved_stdout = sys.stdout
+        sys.stdout = self
+        return None
+
+    def __exit__(self, *args) -> bool:
+        sys.stdout = self.saved_stdout
+        print(self.buf)
+        return None
+
+
+class ProgrammableColorizer(StdoutInterceptor):
+    def __init__(self, patterns: Iterable[Tuple[re.Pattern, Callable[[Any, re.Pattern], str]]]):
+        super().__init__()
+        self.patterns = [_ for _ in patterns]
+
+    def write(self, s: str):
+        for pattern in self.patterns:
+            s = pattern[0].sub(pattern[1], s)
+        self.buf += s
+
+
 if __name__ == '__main__':
     def main() -> None:
         name = " ".join(sys.argv[1:])