Record streams.
[python_utils.git] / tests / logging_utils_test.py
index 48e5015f12f1f0af1595c3e1aa026bdcb9236123..87c00d64e30e3a9b0a90e08658742ad0ae2739aa 100755 (executable)
@@ -1,12 +1,13 @@
 #!/usr/bin/env python3
 
-import contextlib
+import os
+import sys
 import tempfile
 import unittest
 
-import bootstrap
 import logging_utils as lutils
 import string_utils as sutils
+import unittest_utils as uu
 
 
 class TestLoggingUtils(unittest.TestCase):
@@ -17,26 +18,37 @@ class TestLoggingUtils(unittest.TestCase):
         secret_message = f'This is a test, {unique_suffix}.'
 
         with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile1:
-            with tempfile.SpooledTemporaryFile(mode='r+') as tmpfile2:
-                with contextlib.redirect_stdout(tmpfile1):
-                    with lutils.OutputContext(
-                            lutils.OutputMultiplexer.Destination.FILENAME |
-                            lutils.OutputMultiplexer.Destination.FILEHANDLE |
-                            lutils.OutputMultiplexer.Destination.STDOUT,
-                            filename = filename,
-                            handle = tmpfile2,
-                    ) as mplex:
-                        mplex.print(secret_message, end='')
-                    with open(filename, 'r') as rf:
-                        self.assertEqual(rf.readline(), secret_message)
-                tmpfile2.seek(0)
-                tmp = tmpfile2.readline()
-                self.assertEqual(tmp, secret_message)
+            with uu.RecordStdout() as record:
+                with lutils.OutputMultiplexerContext(
+                        lutils.OutputMultiplexer.Destination.FILENAMES |
+                        lutils.OutputMultiplexer.Destination.FILEHANDLES |
+                        lutils.OutputMultiplexer.Destination.LOG_INFO,
+                        filenames = [filename, '/dev/null'],
+                        handles = [tmpfile1, sys.stdout],
+                ) as mplex:
+                    mplex.print(secret_message, end='')
+
+                # Make sure it was written to the filename.
+                with open(filename, 'r') as rf:
+                    self.assertEqual(rf.readline(), secret_message)
+                os.unlink(filename)
+
+            # Make sure it was written to stdout.
+            tmp = record().readline()
+            self.assertEqual(tmp, secret_message)
+
+            # Make sure it was written to the filehandle.
             tmpfile1.seek(0)
             tmp = tmpfile1.readline()
             self.assertEqual(tmp, secret_message)
 
+    def test_record_streams(self):
+        with uu.RecordMultipleStreams(sys.stderr, sys.stdout) as record:
+            print("This is a test!")
+            print("This is one too.", file=sys.stderr)
+        self.assertEqual(record().readlines(),
+                         ["This is a test!\n", "This is one too.\n"])
+
 
 if __name__ == '__main__':
-    unittest.main = bootstrap.initialize(unittest.main)
     unittest.main()