005761a5cccf3d5e90bd9ff3020543aed6dcc59c
[python_utils.git] / logging_utils.py
1 #!/usr/bin/env python3
2
3 """Utilities related to logging."""
4
5 import collections
6 import contextlib
7 import datetime
8 import enum
9 import io
10 import logging
11 from logging.handlers import RotatingFileHandler, SysLogHandler
12 import os
13 import pytz
14 import random
15 import sys
16 from typing import Callable, Iterable, Mapping, Optional
17
18 from overrides import overrides
19
20 # This module is commonly used by others in here and should avoid
21 # taking any unnecessary dependencies back on them.
22 import argparse_utils
23 import config
24
25 cfg = config.add_commandline_args(
26     f'Logging ({__file__})',
27     'Args related to logging')
28 cfg.add_argument(
29     '--logging_config_file',
30     type=argparse_utils.valid_filename,
31     default=None,
32     metavar='FILENAME',
33     help='Config file containing the logging setup, see: https://docs.python.org/3/howto/logging.html#logging-advanced-tutorial',
34 )
35 cfg.add_argument(
36     '--logging_level',
37     type=str,
38     default='INFO',
39     choices=['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
40     metavar='LEVEL',
41     help='The global default level below which to squelch log messages; see also --lmodule',
42 )
43 cfg.add_argument(
44     '--logging_format',
45     type=str,
46     default=None,
47     help='The format for lines logged via the logger module.  See: https://docs.python.org/3/library/logging.html#formatter-objects'
48 )
49 cfg.add_argument(
50     '--logging_date_format',
51     type=str,
52     default='%Y/%m/%dT%H:%M:%S.%f%z',
53     metavar='DATEFMT',
54     help='The format of any dates in --logging_format.'
55 )
56 cfg.add_argument(
57     '--logging_console',
58     action=argparse_utils.ActionNoYes,
59     default=True,
60     help='Should we log to the console (stderr)',
61 )
62 cfg.add_argument(
63     '--logging_filename',
64     type=str,
65     default=None,
66     metavar='FILENAME',
67     help='The filename of the logfile to write.'
68 )
69 cfg.add_argument(
70     '--logging_filename_maxsize',
71     type=int,
72     default=(1024*1024),
73     metavar='#BYTES',
74     help='The maximum size (in bytes) to write to the logging_filename.'
75 )
76 cfg.add_argument(
77     '--logging_filename_count',
78     type=int,
79     default=7,
80     metavar='COUNT',
81     help='The number of logging_filename copies to keep before deleting.'
82 )
83 cfg.add_argument(
84     '--logging_syslog',
85     action=argparse_utils.ActionNoYes,
86     default=False,
87     help='Should we log to localhost\'s syslog.'
88 )
89 cfg.add_argument(
90     '--logging_syslog_facility',
91     type=str,
92     default = 'USER',
93     choices=['NOTSET', 'AUTH', 'AUTH_PRIV', 'CRON', 'DAEMON', 'FTP', 'KERN', 'LPR', 'MAIL', 'NEWS',
94              'SYSLOG', 'USER', 'UUCP', 'LOCAL0', 'LOCAL1', 'LOCAL2', 'LOCAL3', 'LOCAL4', 'LOCAL5',
95              'LOCAL6', 'LOCAL7'],
96     metavar='SYSLOG_FACILITY_LIST',
97     help='The default syslog message facility identifier',
98 )
99 cfg.add_argument(
100     '--logging_debug_threads',
101     action=argparse_utils.ActionNoYes,
102     default=False,
103     help='Should we prepend pid/tid data to all log messages?'
104 )
105 cfg.add_argument(
106     '--logging_debug_modules',
107     action=argparse_utils.ActionNoYes,
108     default=False,
109     help='Should we prepend module/function data to all log messages?'
110 )
111 cfg.add_argument(
112     '--logging_info_is_print',
113     action=argparse_utils.ActionNoYes,
114     default=False,
115     help='logging.info also prints to stdout.'
116 )
117 cfg.add_argument(
118     '--logging_squelch_repeats_enabled',
119     action=argparse_utils.ActionNoYes,
120     default=True,
121     help='Do we allow code to indicate that it wants to squelch repeated logging messages or should we always log?'
122 )
123 cfg.add_argument(
124     '--logging_probabilistically_enabled',
125     action=argparse_utils.ActionNoYes,
126     default=True,
127     help='Do we allow probabilistic logging (for code that wants it) or should we always log?'
128 )
129 # See also: OutputMultiplexer
130 cfg.add_argument(
131     '--logging_captures_prints',
132     action=argparse_utils.ActionNoYes,
133     default=False,
134     help='When calling print, also log.info automatically.'
135 )
136 cfg.add_argument(
137     '--lmodule',
138     type=str,
139     metavar='<SCOPE>=<LEVEL>[,<SCOPE>=<LEVEL>...]',
140     help=(
141         'Allows per-scope logging levels which override the global level set with --logging-level.' +
142         'Pass a space separated list of <scope>=<level> where <scope> is one of: module, ' +
143         'module:function, or :function and <level> is a logging level (e.g. INFO, DEBUG...)'
144     )
145 )
146 cfg.add_argument(
147     '--logging_clear_spammy_handlers',
148     action=argparse_utils.ActionNoYes,
149     default=False,
150     help=(
151         'Should logging code clear preexisting global logging handlers and thus insist that is ' +
152         'alone can add handlers.  Use this to work around annoying modules that insert global ' +
153         'handlers with formats and logging levels you might now want.  Caveat emptor, this may ' +
154         'cause you to miss logging messages.'
155     )
156 )
157
158
159 built_in_print = print
160
161
162 def function_identifier(f: Callable) -> str:
163     """
164     Given a callable function, return a string that identifies it.
165     Usually that string is just __module__:__name__ but there's a
166     corner case: when __module__ is __main__ (i.e. the callable is
167     defined in the same module as __main__).  In this case,
168     f.__module__ returns "__main__" instead of the file that it is
169     defined in.  Work around this using pathlib.Path (see below).
170
171     >>> function_identifier(function_identifier)
172     'logging_utils:function_identifier'
173
174     """
175     if f.__module__ == '__main__':
176         from pathlib import Path
177         import __main__
178         module = __main__.__file__
179         module = Path(module).stem
180         return f'{module}:{f.__name__}'
181     else:
182         return f'{f.__module__}:{f.__name__}'
183
184
185 # A map from logging_callsite_id -> count of logged messages.
186 squelched_logging_counts: Mapping[str, int] = {}
187
188
189 def squelch_repeated_log_messages(squelch_after_n_repeats: int) -> Callable:
190     """
191     A decorator that marks a function as interested in having the logging
192     messages that it produces be squelched (ignored) after it logs the
193     same message more than N times.
194
195     Note: this decorator affects *ALL* logging messages produced
196     within the decorated function.  That said, messages must be
197     identical in order to be squelched.  For example, if the same line
198     of code produces different messages (because of, e.g., a format
199     string), the messages are considered to be different.
200
201     """
202     def squelch_logging_wrapper(f: Callable):
203         identifier = function_identifier(f)
204         squelched_logging_counts[identifier] = squelch_after_n_repeats
205         return f
206     return squelch_logging_wrapper
207
208
209 class SquelchRepeatedMessagesFilter(logging.Filter):
210     """
211     A filter that only logs messages from a given site with the same
212     (exact) message at the same logging level N times and ignores
213     subsequent attempts to log.
214
215     This filter only affects logging messages that repeat more than
216     a threshold number of times from functions that are tagged with
217     the @logging_utils.squelched_logging_ok decorator.
218
219     """
220     def __init__(self) -> None:
221         self.counters = collections.Counter()
222         super().__init__()
223
224     @overrides
225     def filter(self, record: logging.LogRecord) -> bool:
226         id1 = f'{record.module}:{record.funcName}'
227         if id1 not in squelched_logging_counts:
228             return True
229         threshold = squelched_logging_counts[id1]
230         logsite = f'{record.pathname}+{record.lineno}+{record.levelno}+{record.msg}'
231         count = self.counters[logsite]
232         self.counters[logsite] += 1
233         return count < threshold
234
235
236 class DynamicPerScopeLoggingLevelFilter(logging.Filter):
237     """Only interested in seeing logging messages from an allow list of
238     module names or module:function names.  Block others.
239
240     """
241     @staticmethod
242     def level_name_to_level(name: str) -> int:
243         numeric_level = getattr(
244             logging,
245             name,
246             None
247         )
248         if not isinstance(numeric_level, int):
249             raise ValueError('Invalid level: {name}')
250         return numeric_level
251
252     def __init__(
253             self,
254             default_logging_level: int,
255             per_scope_logging_levels: str,
256     ) -> None:
257         super().__init__()
258         self.valid_levels = set(['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'])
259         self.default_logging_level = default_logging_level
260         self.level_by_scope = {}
261         if per_scope_logging_levels is not None:
262             for chunk in per_scope_logging_levels.split(','):
263                 if '=' not in chunk:
264                     print(
265                         f'Malformed lmodule directive: "{chunk}", missing "=".  Ignored.',
266                         file=sys.stderr
267                     )
268                     continue
269                 try:
270                     (scope, level) = chunk.split('=')
271                 except ValueError:
272                     print(
273                         f'Malformed lmodule directive: "{chunk}".  Ignored.',
274                         file=sys.stderr
275                     )
276                     continue
277                 scope = scope.strip()
278                 level = level.strip().upper()
279                 if level not in self.valid_levels:
280                     print(
281                         f'Malformed lmodule directive: "{chunk}", bad level.  Ignored.',
282                         file=sys.stderr
283                     )
284                     continue
285                 self.level_by_scope[scope] = (
286                     DynamicPerScopeLoggingLevelFilter.level_name_to_level(
287                         level
288                     )
289                 )
290
291     @overrides
292     def filter(self, record: logging.LogRecord) -> bool:
293         # First try to find a logging level by scope (--lmodule)
294         if len(self.level_by_scope) > 0:
295             min_level = None
296             for scope in (
297                     record.module,
298                     f'{record.module}:{record.funcName}',
299                     f':{record.funcName}'
300             ):
301                 level = self.level_by_scope.get(scope, None)
302                 if level is not None:
303                     if min_level is None or level < min_level:
304                         min_level = level
305
306             # If we found one, use it instead of the global default level.
307             if min_level is not None:
308                 return record.levelno >= min_level
309
310         # Otherwise, use the global logging level (--logging_level)
311         return record.levelno >= self.default_logging_level
312
313
314 # A map from function_identifier -> probability of logging (0.0%..100.0%)
315 probabilistic_logging_levels: Mapping[str, float] = {}
316
317
318 def logging_is_probabilistic(probability_of_logging: float) -> Callable:
319     """
320     A decorator that indicates that all logging statements within the
321     scope of a particular (marked) function are not deterministic
322     (i.e. they do not always unconditionally log) but rather are
323     probabilistic (i.e. they log N% of the time randomly).
324
325     This affects *ALL* logging statements within the marked function.
326
327     """
328     def probabilistic_logging_wrapper(f: Callable):
329         identifier = function_identifier(f)
330         probabilistic_logging_levels[identifier] = probability_of_logging
331         return f
332     return probabilistic_logging_wrapper
333
334
335 class ProbabilisticFilter(logging.Filter):
336     """
337     A filter that logs messages probabilistically (i.e. randomly at some
338     percent chance).
339
340     This filter only affects logging messages from functions that have
341     been tagged with the @logging_utils.probabilistic_logging decorator.
342
343     """
344     @overrides
345     def filter(self, record: logging.LogRecord) -> bool:
346         id1 = f'{record.module}:{record.funcName}'
347         if id1 not in probabilistic_logging_levels:
348             return True
349         threshold = probabilistic_logging_levels[id1]
350         return (random.random() * 100.0) <= threshold
351
352
353 class OnlyInfoFilter(logging.Filter):
354     """
355     A filter that only logs messages produced at the INFO logging
356     level.  This is used by the logging_info_is_print commandline
357     option to select a subset of the logging stream to send to a
358     stdout handler.
359
360     """
361     @overrides
362     def filter(self, record: logging.LogRecord):
363         return record.levelno == logging.INFO
364
365
366 class MillisecondAwareFormatter(logging.Formatter):
367     """
368     A formatter for adding milliseconds to log messages.
369
370     """
371     converter = datetime.datetime.fromtimestamp
372
373     @overrides
374     def formatTime(self, record, datefmt=None):
375         ct = MillisecondAwareFormatter.converter(
376             record.created, pytz.timezone("US/Pacific")
377         )
378         if datefmt:
379             s = ct.strftime(datefmt)
380         else:
381             t = ct.strftime("%Y-%m-%d %H:%M:%S")
382             s = "%s,%03d" % (t, record.msecs)
383         return s
384
385
386 def initialize_logging(logger=None) -> logging.Logger:
387     assert config.has_been_parsed()
388     if logger is None:
389         logger = logging.getLogger()       # Root logger
390
391     spammy_handlers = 0
392     if config.config['logging_clear_spammy_handlers']:
393         while logger.hasHandlers():
394             logger.removeHandler(logger.handlers[0])
395             spammy_handlers += 1
396
397     if config.config['logging_config_file'] is not None:
398         logging.config.fileConfig('logging.conf')
399         return logger
400
401     handlers = []
402
403     # Global default logging level (--logging_level)
404     default_logging_level = getattr(
405         logging,
406         config.config['logging_level'].upper(),
407         None
408     )
409     if not isinstance(default_logging_level, int):
410         raise ValueError('Invalid level: %s' % config.config['logging_level'])
411
412     if config.config['logging_format']:
413         fmt = config.config['logging_format']
414     else:
415         if config.config['logging_syslog']:
416             fmt = '%(levelname).1s:%(filename)s[%(process)d]: %(message)s'
417         else:
418             fmt = '%(levelname).1s:%(asctime)s: %(message)s'
419
420     if config.config['logging_debug_threads']:
421         fmt = f'%(process)d.%(thread)d|{fmt}'
422     if config.config['logging_debug_modules']:
423         fmt = f'%(filename)s:%(funcName)s:%(lineno)s|{fmt}'
424
425     if config.config['logging_syslog']:
426         if sys.platform not in ('win32', 'cygwin'):
427             if config.config['logging_syslog_facility']:
428                 facility_name = 'LOG_' + config.config['logging_syslog_facility']
429             facility = SysLogHandler.__dict__.get(facility_name, SysLogHandler.LOG_USER)
430             handler = SysLogHandler(facility=facility, address='/dev/log')
431             handler.setFormatter(
432                 MillisecondAwareFormatter(
433                     fmt=fmt,
434                     datefmt=config.config['logging_date_format'],
435                 )
436             )
437             handlers.append(handler)
438
439     if config.config['logging_filename']:
440         handler = RotatingFileHandler(
441             config.config['logging_filename'],
442             maxBytes = config.config['logging_filename_maxsize'],
443             backupCount = config.config['logging_filename_count'],
444         )
445         handler.setFormatter(
446             MillisecondAwareFormatter(
447                 fmt=fmt,
448                 datefmt=config.config['logging_date_format'],
449             )
450         )
451         handlers.append(handler)
452
453     if config.config['logging_console']:
454         handler = logging.StreamHandler(sys.stderr)
455         handler.setFormatter(
456             MillisecondAwareFormatter(
457                 fmt=fmt,
458                 datefmt=config.config['logging_date_format'],
459             )
460         )
461         handlers.append(handler)
462
463     if len(handlers) == 0:
464         handlers.append(logging.NullHandler())
465
466     for handler in handlers:
467         logger.addHandler(handler)
468
469     if config.config['logging_info_is_print']:
470         handler = logging.StreamHandler(sys.stdout)
471         handler.addFilter(OnlyInfoFilter())
472         logger.addHandler(handler)
473
474     if config.config['logging_squelch_repeats_enabled']:
475         for handler in handlers:
476             handler.addFilter(SquelchRepeatedMessagesFilter())
477
478     if config.config['logging_probabilistically_enabled']:
479         for handler in handlers:
480             handler.addFilter(ProbabilisticFilter())
481
482     for handler in handlers:
483         handler.addFilter(
484             DynamicPerScopeLoggingLevelFilter(
485                 default_logging_level,
486                 config.config['lmodule'],
487             )
488         )
489     logger.setLevel(0)
490     logger.propagate = False
491
492     if config.config['logging_captures_prints']:
493         import builtins
494         global built_in_print
495
496         def print_and_also_log(*arg, **kwarg):
497             f = kwarg.get('file', None)
498             if f == sys.stderr:
499                 logger.warning(*arg)
500             else:
501                 logger.info(*arg)
502             built_in_print(*arg, **kwarg)
503         builtins.print = print_and_also_log
504
505     logger.debug(f'Initialized logger; default logging level is {default_logging_level}.')
506     if config.config['logging_clear_spammy_handlers'] and spammy_handlers > 0:
507         logger.warning(
508             'Logging cleared {spammy_handlers} global handlers (--logging_clear_spammy_handlers)'
509         )
510     logger.debug(f'Logging format is "{fmt}"')
511     if config.config['logging_syslog']:
512         logger.debug(f'Logging to syslog as {facility_name} with normal severity mapping')
513     if config.config['logging_filename']:
514         logger.debug(f'Logging to filename {config.config["logging_filename"]} with rotation')
515     if config.config['logging_console']:
516         logger.debug(f'Logging to the console.')
517     if config.config['logging_info_is_print']:
518         logger.debug(
519             'Logging logger.info messages will be repeated on stdout (--logging_info_is_print)'
520         )
521     if config.config['logging_squelch_repeats_enabled']:
522         logger.debug(
523             'Logging code is allowed to request repeated messages be squelched (--logging_squelch_repeats_enabled)'
524         )
525     if config.config['logging_probabilistically_enabled']:
526         logger.debug(
527             'Logging code is allowed to request probabilistic logging (--logging_probabilistically_enabled)'
528         )
529     if config.config['lmodule']:
530         logger.debug(
531             'Logging dynamic per-module logging enabled (--lmodule={config.config["lmodule"]})'
532         )
533     if config.config['logging_captures_prints']:
534         logger.debug('Logging will capture printed messages (--logging_captures_prints)')
535     return logger
536
537
538 def get_logger(name: str = ""):
539     logger = logging.getLogger(name)
540     return initialize_logging(logger)
541
542
543 def tprint(*args, **kwargs) -> None:
544     """Legacy function for printing a message augmented with thread id."""
545
546     if config.config['logging_debug_threads']:
547         from thread_utils import current_thread_id
548         print(f'{current_thread_id()}', end="")
549         print(*args, **kwargs)
550     else:
551         pass
552
553
554 def dprint(*args, **kwargs) -> None:
555     """Legacy function used to print to stderr."""
556
557     print(*args, file=sys.stderr, **kwargs)
558
559
560 class OutputMultiplexer(object):
561     """
562     A class that broadcasts printed messages to several sinks (including
563     various logging levels, different files, different file handles,
564     the house log, etc...)
565
566     """
567     class Destination(enum.IntEnum):
568         """Bits in the destination_bitv bitvector.  Used to indicate the
569         output destination."""
570         LOG_DEBUG = 0x01         #  ⎫
571         LOG_INFO = 0x02          #  ⎪
572         LOG_WARNING = 0x04       #  ⎬ Must provide logger to the c'tor.
573         LOG_ERROR = 0x08         #  ⎪
574         LOG_CRITICAL = 0x10      #  ⎭
575         FILENAMES = 0x20         # Must provide a filename to the c'tor.
576         FILEHANDLES = 0x40       # Must provide a handle to the c'tor.
577         HLOG = 0x80
578         ALL_LOG_DESTINATIONS = (
579             LOG_DEBUG | LOG_INFO | LOG_WARNING | LOG_ERROR | LOG_CRITICAL
580         )
581         ALL_OUTPUT_DESTINATIONS = 0x8F
582
583     def __init__(self,
584                  destination_bitv: int,
585                  *,
586                  logger=None,
587                  filenames: Optional[Iterable[str]] = None,
588                  handles: Optional[Iterable[io.TextIOWrapper]] = None):
589         if logger is None:
590             logger = logging.getLogger(None)
591         self.logger = logger
592
593         if filenames is not None:
594             self.f = [
595                 open(filename, 'wb', buffering=0) for filename in filenames
596             ]
597         else:
598             if destination_bitv & OutputMultiplexer.FILENAMES:
599                 raise ValueError(
600                     "Filenames argument is required if bitv & FILENAMES"
601                 )
602             self.f = None
603
604         if handles is not None:
605             self.h = [handle for handle in handles]
606         else:
607             if destination_bitv & OutputMultiplexer.Destination.FILEHANDLES:
608                 raise ValueError(
609                     "Handle argument is required if bitv & FILEHANDLES"
610                 )
611             self.h = None
612
613         self.set_destination_bitv(destination_bitv)
614
615     def get_destination_bitv(self):
616         return self.destination_bitv
617
618     def set_destination_bitv(self, destination_bitv: int):
619         if destination_bitv & self.Destination.FILENAMES and self.f is None:
620             raise ValueError(
621                 "Filename argument is required if bitv & FILENAMES"
622             )
623         if destination_bitv & self.Destination.FILEHANDLES and self.h is None:
624             raise ValueError(
625                     "Handle argument is required if bitv & FILEHANDLES"
626                 )
627         self.destination_bitv = destination_bitv
628
629     def print(self, *args, **kwargs):
630         from string_utils import sprintf, strip_escape_sequences
631         end = kwargs.pop("end", None)
632         if end is not None:
633             if not isinstance(end, str):
634                 raise TypeError("end must be None or a string")
635         sep = kwargs.pop("sep", None)
636         if sep is not None:
637             if not isinstance(sep, str):
638                 raise TypeError("sep must be None or a string")
639         if kwargs:
640             raise TypeError("invalid keyword arguments to print()")
641         buf = sprintf(*args, end="", sep=sep)
642         if sep is None:
643             sep = " "
644         if end is None:
645             end = "\n"
646         if end == '\n':
647             buf += '\n'
648         if (
649                 self.destination_bitv & self.Destination.FILENAMES and
650                 self.f is not None
651         ):
652             for _ in self.f:
653                 _.write(buf.encode('utf-8'))
654                 _.flush()
655
656         if (
657                 self.destination_bitv & self.Destination.FILEHANDLES and
658                 self.h is not None
659         ):
660             for _ in self.h:
661                 _.write(buf)
662                 _.flush()
663
664         buf = strip_escape_sequences(buf)
665         if self.logger is not None:
666             if self.destination_bitv & self.Destination.LOG_DEBUG:
667                 self.logger.debug(buf)
668             if self.destination_bitv & self.Destination.LOG_INFO:
669                 self.logger.info(buf)
670             if self.destination_bitv & self.Destination.LOG_WARNING:
671                 self.logger.warning(buf)
672             if self.destination_bitv & self.Destination.LOG_ERROR:
673                 self.logger.error(buf)
674             if self.destination_bitv & self.Destination.LOG_CRITICAL:
675                 self.logger.critical(buf)
676         if self.destination_bitv & self.Destination.HLOG:
677             hlog(buf)
678
679     def close(self):
680         if self.f is not None:
681             for _ in self.f:
682                 _.close()
683
684
685 class OutputMultiplexerContext(OutputMultiplexer, contextlib.ContextDecorator):
686     """
687     A context that uses an OutputMultiplexer.  e.g.
688
689         with OutputMultiplexerContext(
690                 OutputMultiplexer.LOG_INFO |
691                 OutputMultiplexer.LOG_DEBUG |
692                 OutputMultiplexer.FILENAMES |
693                 OutputMultiplexer.FILEHANDLES,
694                 filenames = [ '/tmp/foo.log', '/var/log/bar.log' ],
695                 handles = [ f, g ]
696             ) as mplex:
697                 mplex.print("This is a log message!")
698
699     """
700     def __init__(self,
701                  destination_bitv: OutputMultiplexer.Destination,
702                  *,
703                  logger = None,
704                  filenames = None,
705                  handles = None):
706         super().__init__(
707             destination_bitv,
708             logger=logger,
709             filenames=filenames,
710             handles=handles)
711
712     def __enter__(self):
713         return self
714
715     def __exit__(self, etype, value, traceback) -> bool:
716         super().close()
717         if etype is not None:
718             return False
719         return True
720
721
722 def hlog(message: str) -> None:
723     """Write a message to the house log."""
724
725     message = message.replace("'", "'\"'\"'")
726     os.system(f"/usr/bin/logger -p local7.info -- '{message}'")
727
728
729 if __name__ == '__main__':
730     import doctest
731     doctest.testmod()