Record streams.
[python_utils.git] / tests / logging_utils_test.py
1 #!/usr/bin/env python3
2
3 import os
4 import sys
5 import tempfile
6 import unittest
7
8 import logging_utils as lutils
9 import string_utils as sutils
10 import unittest_utils as uu
11
12
class TestLoggingUtils(unittest.TestCase):
    """Tests for logging_utils.OutputMultiplexerContext and the
    unittest_utils stream recorders."""

    def test_output_context(self):
        """A message printed through OutputMultiplexerContext must show up in
        the named file, in the supplied file handle and on stdout."""
        unique_suffix = sutils.generate_uuid(True)
        filename = f'/tmp/logging_utils_test.{unique_suffix}'
        secret_message = f'This is a test, {unique_suffix}.'

        with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile1:
            with uu.RecordStdout() as record:
                try:
                    # Fan the message out to named files, open handles and
                    # the LOG_INFO destination at once.
                    with lutils.OutputMultiplexerContext(
                        lutils.OutputMultiplexer.Destination.FILENAMES
                        | lutils.OutputMultiplexer.Destination.FILEHANDLES
                        | lutils.OutputMultiplexer.Destination.LOG_INFO,
                        filenames=[filename, '/dev/null'],
                        handles=[tmpfile1, sys.stdout],
                    ) as mplex:
                        mplex.print(secret_message, end='')

                    # Make sure it was written to the filename.
                    with open(filename, 'r') as rf:
                        self.assertEqual(rf.readline(), secret_message)
                finally:
                    # Always remove the scratch file, even when the check
                    # above fails.  A missing file is ignored here so that
                    # the cleanup never hides the original error.
                    try:
                        os.unlink(filename)
                    except FileNotFoundError:
                        pass

            # Make sure it was written to stdout.
            tmp = record().readline()
            self.assertEqual(tmp, secret_message)

            # Make sure it was written to the filehandle.  Rewind first: the
            # write left the file position at the end of the message.
            tmpfile1.seek(0)
            tmp = tmpfile1.readline()
            self.assertEqual(tmp, secret_message)

    def test_record_streams(self):
        """Text printed to stdout and stderr inside RecordMultipleStreams is
        readable from the recording, stdout line first as asserted below."""
        with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record:
            print("This is a test!")
            print("This is one too.", file=sys.stderr)
        self.assertEqual(
            record().readlines(),
            ["This is a test!\n", "This is one too.\n"],
        )
51
52
# Script entry point: hand control to unittest's command-line runner when
# this module is executed directly rather than imported.
if __name__ == '__main__':
    unittest.main()