Ugh, a bunch of things. @overrides. --lmodule. Chromecasts. etc...
[python_utils.git] / logging_utils.py
1 #!/usr/bin/env python3
2
3 """Utilities related to logging."""
4
5 import collections
6 import contextlib
7 import datetime
8 import enum
9 import io
10 import logging
11 from logging.handlers import RotatingFileHandler, SysLogHandler
12 import os
13 import pytz
14 import random
15 import sys
16 from typing import Callable, Iterable, Mapping, Optional
17
18 from overrides import overrides
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import argparse_utils
23 import config
24
25 cfg = config.add_commandline_args(
26     f'Logging ({__file__})',
27     'Args related to logging')
28 cfg.add_argument(
29     '--logging_config_file',
30     type=argparse_utils.valid_filename,
31     default=None,
32     metavar='FILENAME',
33     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
34 )
35 cfg.add_argument(
36     '--logging_level',
37     type=str,
38     default='INFO',
39     choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
40     metavar='LEVEL',
41     help='The global default level below which to squelch log messages; see also --lmodule',
42 )
43 cfg.add_argument(
44     '--logging_format',
45     type=str,
46     default='%(levelname).1s:%(asctime)s: %(message)s',
47     help='The format for lines logged via the logger module.'
48 )
49 cfg.add_argument(
50     '--logging_date_format',
51     type=str,
52     default='%Y/%m/%dT%H:%M:%S.%f%z',
53     metavar='DATEFMT',
54     help='The format of any dates in --logging_format.'
55 )
56 cfg.add_argument(
57     '--logging_console',
58     action=argparse_utils.ActionNoYes,
59     default=True,
60     help='Should we log to the console (stderr)',
61 )
62 cfg.add_argument(
63     '--logging_filename',
64     type=str,
65     default=None,
66     metavar='FILENAME',
67     help='The filename of the logfile to write.'
68 )
69 cfg.add_argument(
70     '--logging_filename_maxsize',
71     type=int,
72     default=(1024*1024),
73     metavar='#BYTES',
74     help='The maximum size (in bytes) to write to the logging_filename.'
75 )
76 cfg.add_argument(
77     '--logging_filename_count',
78     type=int,
79     default=2,
80     metavar='COUNT',
81     help='The number of logging_filename copies to keep before deleting.'
82 )
83 cfg.add_argument(
84     '--logging_syslog',
85     action=argparse_utils.ActionNoYes,
86     default=False,
87     help='Should we log to localhost\'s syslog.'
88 )
89 cfg.add_argument(
90     '--logging_debug_threads',
91     action=argparse_utils.ActionNoYes,
92     default=False,
93     help='Should we prepend pid/tid data to all log messages?'
94 )
95 cfg.add_argument(
96     '--logging_info_is_print',
97     action=argparse_utils.ActionNoYes,
98     default=False,
99     help='logging.info also prints to stdout.'
100 )
101 cfg.add_argument(
102     '--logging_squelch_repeats_enabled',
103     action=argparse_utils.ActionNoYes,
104     default=True,
105     help='Do we allow code to indicate that it wants to squelch repeated logging messages or should we always log?'
106 )
107 cfg.add_argument(
108     '--logging_probabilistically_enabled',
109     action=argparse_utils.ActionNoYes,
110     default=True,
111     help='Do we allow probabilistic logging (for code that wants it) or should we always log?'
112 )
113 # See also: OutputMultiplexer
114 cfg.add_argument(
115     '--logging_captures_prints',
116     action=argparse_utils.ActionNoYes,
117     default=False,
118     help='When calling print, also log.info automatically.'
119 )
120 cfg.add_argument(
121     '--lmodule',
122     type=str,
123     metavar='<SCOPE>=<LEVEL>[,<SCOPE>=<LEVEL>...]',
124     help=(
125         'Allows per-scope logging levels which override the global level set with --logging-level.' +
126         'Pass a space separated list of <scope>=<level> where <scope> is one of: module, ' +
127         'module:function, or :function and <level> is a logging level (e.g. INFO, DEBUG...)'
128     )
129 )
130
131
132 built_in_print = print
133
134
135 def function_identifier(f: Callable) -> str:
136     """
137     Given a callable function, return a string that identifies it.
138     Usually that string is just __module__:__name__ but there's a
139     corner case: when __module__ is __main__ (i.e. the callable is
140     defined in the same module as __main__).  In this case,
141     f.__module__ returns "__main__" instead of the file that it is
142     defined in.  Work around this using pathlib.Path (see below).
143
144     >>> function_identifier(function_identifier)
145     'logging_utils:function_identifier'
146
147     """
148     if f.__module__ == '__main__':
149         from pathlib import Path
150         import __main__
151         module = __main__.__file__
152         module = Path(module).stem
153         return f'{module}:{f.__name__}'
154     else:
155         return f'{f.__module__}:{f.__name__}'
156
157
158 # A map from logging_callsite_id -> count of logged messages.
159 squelched_logging_counts: Mapping[str, int] = {}
160
161
162 def squelch_repeated_log_messages(squelch_after_n_repeats: int) -> Callable:
163     """
164     A decorator that marks a function as interested in having the logging
165     messages that it produces be squelched (ignored) after it logs the
166     same message more than N times.
167
168     Note: this decorator affects *ALL* logging messages produced
169     within the decorated function.  That said, messages must be
170     identical in order to be squelched.  For example, if the same line
171     of code produces different messages (because of, e.g., a format
172     string), the messages are considered to be different.
173
174     """
175     def squelch_logging_wrapper(f: Callable):
176         identifier = function_identifier(f)
177         squelched_logging_counts[identifier] = squelch_after_n_repeats
178         return f
179     return squelch_logging_wrapper
180
181
182 class SquelchRepeatedMessagesFilter(logging.Filter):
183     """
184     A filter that only logs messages from a given site with the same
185     (exact) message at the same logging level N times and ignores
186     subsequent attempts to log.
187
188     This filter only affects logging messages that repeat more than
189     a threshold number of times from functions that are tagged with
190     the @logging_utils.squelched_logging_ok decorator.
191
192     """
193     def __init__(self) -> None:
194         self.counters = collections.Counter()
195         super().__init__()
196
197     @overrides
198     def filter(self, record: logging.LogRecord) -> bool:
199         id1 = f'{record.module}:{record.funcName}'
200         if id1 not in squelched_logging_counts:
201             return True
202         threshold = squelched_logging_counts[id1]
203         logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
204         count = self.counters[logsite]
205         self.counters[logsite] += 1
206         return count < threshold
207
208
209 class DynamicPerScopeLoggingLevelFilter(logging.Filter):
210     """Only interested in seeing logging messages from an allow list of
211     module names or module:function names.  Block others.
212
213     """
214     @staticmethod
215     def level_name_to_level(name: str) -> int:
216         numeric_level = getattr(
217             logging,
218             name,
219             None
220         )
221         if not isinstance(numeric_level, int):
222             raise ValueError('Invalid level: {name}')
223         return numeric_level
224
225     def __init__(
226             self,
227             default_logging_level: int,
228             per_scope_logging_levels: str,
229     ) -> None:
230         super().__init__()
231         self.valid_levels = set(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
232         self.default_logging_level = default_logging_level
233         self.level_by_scope = {}
234         if per_scope_logging_levels is not None:
235             for chunk in per_scope_logging_levels.split(','):
236                 if '=' not in chunk:
237                     print(
238                         f'Malformed lmodule directive: "{chunk}", missing "=".  Ignored.',
239                         file=sys.stderr
240                     )
241                     continue
242                 try:
243                     (scope, level) = chunk.split('=')
244                 except ValueError:
245                     print(
246                         f'Malformed lmodule directive: "{chunk}".  Ignored.',
247                         file=sys.stderr
248                     )
249                     continue
250                 scope = scope.strip()
251                 level = level.strip().upper()
252                 if level not in self.valid_levels:
253                     print(
254                         f'Malformed lmodule directive: "{chunk}", bad level.  Ignored.',
255                         file=sys.stderr
256                     )
257                     continue
258                 self.level_by_scope[scope] = (
259                     DynamicPerScopeLoggingLevelFilter.level_name_to_level(
260                         level
261                     )
262                 )
263
264     @overrides
265     def filter(self, record: logging.LogRecord) -> bool:
266         # First try to find a logging level by scope (--lmodule)
267         if len(self.level_by_scope) > 0:
268             min_level = None
269             for scope in (
270                     record.module,
271                     f'{record.module}:{record.funcName}',
272                     f':{record.funcName}'
273             ):
274                 level = self.level_by_scope.get(scope, None)
275                 if level is not None:
276                     if min_level is None or level < min_level:
277                         min_level = level
278
279             # If we found one, use it instead of the global default level.
280             if min_level is not None:
281                 return record.levelno >= min_level
282
283         # Otherwise, use the global logging level (--logging_level)
284         return record.levelno >= self.default_logging_level
285
286
287 # A map from function_identifier -> probability of logging (0.0%..100.0%)
288 probabilistic_logging_levels: Mapping[str, float] = {}
289
290
291 def logging_is_probabilistic(probability_of_logging: float) -> Callable:
292     """
293     A decorator that indicates that all logging statements within the
294     scope of a particular (marked) function are not deterministic
295     (i.e. they do not always unconditionally log) but rather are
296     probabilistic (i.e. they log N% of the time randomly).
297
298     This affects *ALL* logging statements within the marked function.
299
300     """
301     def probabilistic_logging_wrapper(f: Callable):
302         identifier = function_identifier(f)
303         probabilistic_logging_levels[identifier] = probability_of_logging
304         return f
305     return probabilistic_logging_wrapper
306
307
308 class ProbabilisticFilter(logging.Filter):
309     """
310     A filter that logs messages probabilistically (i.e. randomly at some
311     percent chance).
312
313     This filter only affects logging messages from functions that have
314     been tagged with the @logging_utils.probabilistic_logging decorator.
315
316     """
317     @overrides
318     def filter(self, record: logging.LogRecord) -> bool:
319         id1 = f'{record.module}:{record.funcName}'
320         if id1 not in probabilistic_logging_levels:
321             return True
322         threshold = probabilistic_logging_levels[id1]
323         return (random.random() * 100.0) <= threshold
324
325
326 class OnlyInfoFilter(logging.Filter):
327     """
328     A filter that only logs messages produced at the INFO logging
329     level.  This is used by the logging_info_is_print commandline
330     option to select a subset of the logging stream to send to a
331     stdout handler.
332
333     """
334     @overrides
335     def filter(self, record: logging.LogRecord):
336         return record.levelno == logging.INFO
337
338
339 class MillisecondAwareFormatter(logging.Formatter):
340     """
341     A formatter for adding milliseconds to log messages.
342
343     """
344     converter = datetime.datetime.fromtimestamp
345
346     @overrides
347     def formatTime(self, record, datefmt=None):
348         ct = MillisecondAwareFormatter.converter(
349             record.created, pytz.timezone("US/Pacific")
350         )
351         if datefmt:
352             s = ct.strftime(datefmt)
353         else:
354             t = ct.strftime("%Y-%m-%d %H:%M:%S")
355             s = "%s,%03d" % (t, record.msecs)
356         return s
357
358
359 def initialize_logging(logger=None) -> logging.Logger:
360     assert config.has_been_parsed()
361     if logger is None:
362         logger = logging.getLogger()       # Root logger
363
364     if config.config['logging_config_file'] is not None:
365         logging.config.fileConfig('logging.conf')
366         return logger
367
368     handlers = []
369
370     # Global default logging level (--logging_level)
371     default_logging_level = getattr(
372         logging,
373         config.config['logging_level'].upper(),
374         None
375     )
376     if not isinstance(default_logging_level, int):
377         raise ValueError('Invalid level: %s' % config.config['logging_level'])
378
379     fmt = config.config['logging_format']
380     if config.config['logging_debug_threads']:
381         fmt = f'%(process)d.%(thread)d|{fmt}'
382
383     if config.config['logging_syslog']:
384         if sys.platform not in ('win32', 'cygwin'):
385             handler = SysLogHandler()
386             handler.setFormatter(
387                 MillisecondAwareFormatter(
388                     fmt=fmt,
389                     datefmt=config.config['logging_date_format'],
390                 )
391             )
392             handlers.append(handler)
393
394     if config.config['logging_filename']:
395         handler = RotatingFileHandler(
396             config.config['logging_filename'],
397             maxBytes = config.config['logging_filename_maxsize'],
398             backupCount = config.config['logging_filename_count'],
399         )
400         handler.setFormatter(
401             MillisecondAwareFormatter(
402                 fmt=fmt,
403                 datefmt=config.config['logging_date_format'],
404             )
405         )
406         handlers.append(handler)
407
408     if config.config['logging_console']:
409         handler = logging.StreamHandler(sys.stderr)
410         handler.setFormatter(
411             MillisecondAwareFormatter(
412                 fmt=fmt,
413                 datefmt=config.config['logging_date_format'],
414             )
415         )
416         handlers.append(handler)
417
418     if len(handlers) == 0:
419         handlers.append(logging.NullHandler())
420
421     for handler in handlers:
422         logger.addHandler(handler)
423
424     if config.config['logging_info_is_print']:
425         handler = logging.StreamHandler(sys.stdout)
426         handler.addFilter(OnlyInfoFilter())
427         logger.addHandler(handler)
428
429     if config.config['logging_squelch_repeats_enabled']:
430         for handler in handlers:
431             handler.addFilter(SquelchRepeatedMessagesFilter())
432
433     if config.config['logging_probabilistically_enabled']:
434         for handler in handlers:
435             handler.addFilter(ProbabilisticFilter())
436
437     for handler in handlers:
438         handler.addFilter(
439             DynamicPerScopeLoggingLevelFilter(
440                 default_logging_level,
441                 config.config['lmodule'],
442             )
443         )
444     logger.setLevel(0)
445     logger.propagate = False
446
447     if config.config['logging_captures_prints']:
448         import builtins
449         global built_in_print
450
451         def print_and_also_log(*arg, **kwarg):
452             f = kwarg.get('file', None)
453             if f == sys.stderr:
454                 logger.warning(*arg)
455             else:
456                 logger.info(*arg)
457             built_in_print(*arg, **kwarg)
458         builtins.print = print_and_also_log
459
460     return logger
461
462
463 def get_logger(name: str = ""):
464     logger = logging.getLogger(name)
465     return initialize_logging(logger)
466
467
468 def tprint(*args, **kwargs) -> None:
469     """Legacy function for printing a message augmented with thread id."""
470
471     if config.config['logging_debug_threads']:
472         from thread_utils import current_thread_id
473         print(f'{current_thread_id()}', end="")
474         print(*args, **kwargs)
475     else:
476         pass
477
478
479 def dprint(*args, **kwargs) -> None:
480     """Legacy function used to print to stderr."""
481
482     print(*args, file=sys.stderr, **kwargs)
483
484
485 class OutputMultiplexer(object):
486     """
487     A class that broadcasts printed messages to several sinks (including
488     various logging levels, different files, different file handles,
489     the house log, etc...)
490
491     """
492     class Destination(enum.IntEnum):
493         """Bits in the destination_bitv bitvector.  Used to indicate the
494         output destination."""
495         LOG_DEBUG = 0x01         #  ⎫
496         LOG_INFO = 0x02          #  ⎪
497         LOG_WARNING = 0x04       #  ⎬ Must provide logger to the c'tor.
498         LOG_ERROR = 0x08         #  ⎪
499         LOG_CRITICAL = 0x10      #  ⎭
500         FILENAMES = 0x20         # Must provide a filename to the c'tor.
501         FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
502         HLOG = 0x80
503         ALL_LOG_DESTINATIONS = (
504             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
505         )
506         ALL_OUTPUT_DESTINATIONS = 0x8F
507
508     def __init__(self,
509                  destination_bitv: int,
510                  *,
511                  logger=None,
512                  filenames: Optional[Iterable[str]] = None,
513                  handles: Optional[Iterable[io.TextIOWrapper]] = None):
514         if logger is None:
515             logger = logging.getLogger(None)
516         self.logger = logger
517
518         if filenames is not None:
519             self.f = [
520                 open(filename, 'wb', buffering=0) for filename in filenames
521             ]
522         else:
523             if destination_bitv & OutputMultiplexer.FILENAMES:
524                 raise ValueError(
525                     "Filenames argument is required if bitv & FILENAMES"
526                 )
527             self.f = None
528
529         if handles is not None:
530             self.h = [handle for handle in handles]
531         else:
532             if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES:
533                 raise ValueError(
534                     "Handle argument is required if bitv & FILEHANDLES"
535                 )
536             self.h = None
537
538         self.set_destination_bitv(destination_bitv)
539
540     def get_destination_bitv(self):
541         return self.destination_bitv
542
543     def set_destination_bitv(self, destination_bitv: int):
544         if destination_bitv & self.Destination.FILENAMES and self.f is None:
545             raise ValueError(
546                 "Filename argument is required if bitv & FILENAMES"
547             )
548         if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
549             raise ValueError(
550                     "Handle argument is required if bitv & FILEHANDLES"
551                 )
552         self.destination_bitv = destination_bitv
553
554     def print(self, *args, **kwargs):
555         from string_utils import sprintf, strip_escape_sequences
556         end = kwargs.pop("end", None)
557         if end is not None:
558             if not isinstance(end, str):
559                 raise TypeError("end must be None or a string")
560         sep = kwargs.pop("sep", None)
561         if sep is not None:
562             if not isinstance(sep, str):
563                 raise TypeError("sep must be None or a string")
564         if kwargs:
565             raise TypeError("invalid keyword arguments to print()")
566         buf = sprintf(*args, end="", sep=sep)
567         if sep is None:
568             sep = " "
569         if end is None:
570             end = "\n"
571         if end == '\n':
572             buf += '\n'
573         if (
574                 self.destination_bitv & self.Destination.FILENAMES and
575                 self.f is not None
576         ):
577             for _ in self.f:
578                 _.write(buf.encode('utf-8'))
579                 _.flush()
580
581         if (
582                 self.destination_bitv & self.Destination.FILEHANDLES and
583                 self.h is not None
584         ):
585             for _ in self.h:
586                 _.write(buf)
587                 _.flush()
588
589         buf = strip_escape_sequences(buf)
590         if self.logger is not None:
591             if self.destination_bitv & self.Destination.LOG_DEBUG:
592                 self.logger.debug(buf)
593             if self.destination_bitv & self.Destination.LOG_INFO:
594                 self.logger.info(buf)
595             if self.destination_bitv & self.Destination.LOG_WARNING:
596                 self.logger.warning(buf)
597             if self.destination_bitv & self.Destination.LOG_ERROR:
598                 self.logger.error(buf)
599             if self.destination_bitv & self.Destination.LOG_CRITICAL:
600                 self.logger.critical(buf)
601         if self.destination_bitv & self.Destination.HLOG:
602             hlog(buf)
603
604     def close(self):
605         if self.f is not None:
606             for _ in self.f:
607                 _.close()
608
609
610 class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
611     """
612     A context that uses an OutputMultiplexer.  e.g.
613
614         with OutputMultiplexerContext(
615                 OutputMultiplexer.LOG_INFO |
616                 OutputMultiplexer.LOG_DEBUG |
617                 OutputMultiplexer.FILENAMES |
618                 OutputMultiplexer.FILEHANDLES,
619                 filenames = [ '/tmp/foo.log', '/var/log/bar.log' ],
620                 handles = [ f, g ]
621             ) as mplex:
622                 mplex.print("This is a log message!")
623
624     """
625     def __init__(self,
626                  destination_bitv: OutputMultiplexer.Destination,
627                  *,
628                  logger = None,
629                  filenames = None,
630                  handles = None):
631         super().__init__(
632             destination_bitv,
633             logger=logger,
634             filenames=filenames,
635             handles=handles)
636
637     def __enter__(self):
638         return self
639
640     def __exit__(self, etype, value, traceback) -> bool:
641         super().close()
642         if etype is not None:
643             return False
644         return True
645
646
647 def hlog(message: str) -> None:
648     """Write a message to the house log."""
649
650     message = message.replace("'", "'\"'\"'")
651     os.system(f"/usr/bin/logger -p local7.info -- '{message}'")
652
653
654 if __name__ == '__main__':
655     import doctest
656     doctest.testmod()