Adding doctests. Also added a logging filter.
[python_utils.git] / logging_utils.py
index 25919a765ef2430283cb0e67572d326ca62507f0..034f90c0ee3ab932ffbf92740017fed369884b55 100644 (file)
@@ -2,6 +2,7 @@
 
 """Utilities related to logging."""
 
+import collections
 import contextlib
 import datetime
 import enum
@@ -94,6 +95,12 @@ cfg.add_argument(
     default=False,
     help='logging.info also prints to stdout.'
 )
+cfg.add_argument(
+    '--logging_max_n_times_per_message',
+    type=int,
+    default=0,
+    help='When set, ignore logged messages from the same site after N.'
+)
 
 # See also: OutputMultiplexer
 cfg.add_argument(
@@ -107,11 +114,37 @@ built_in_print = print
 
 
 class OnlyInfoFilter(logging.Filter):
+    """
+    A filter that only logs messages produced at the INFO logging level.
+    """
     def filter(self, record):
         return record.levelno == logging.INFO
 
 
+class OnlyNTimesFilter(logging.Filter):
+    """
+    A filter that only logs messages from a given site with the same
+    message at the same logging level N times and ignores subsequent
+    attempts to log.
+
+    """
+    def __init__(self, maximum: int) -> None:
+        self.maximum = maximum
+        self.counters = collections.Counter()
+        super().__init__()
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        source = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
+        count = self.counters[source]
+        self.counters[source] += 1
+        return count < self.maximum
+
+
 class MillisecondAwareFormatter(logging.Formatter):
+    """
+    A formatter for adding milliseconds to log messages.
+
+    """
     converter = datetime.datetime.fromtimestamp
 
     def formatTime(self, record, datefmt=None):
@@ -199,6 +232,11 @@ def initialize_logging(logger=None) -> logging.Logger:
         handler.addFilter(OnlyInfoFilter())
         logger.addHandler(handler)
 
+    maximum = config.config['logging_max_n_times_per_message']
+    if maximum > 0:
+        for handler in handlers:
+            handler.addFilter(OnlyNTimesFilter(maximum))
+
     logger.setLevel(numeric_level)
     logger.propagate = False