Make persistent use tighter permissions by default.
authorScott Gasch <[email protected]>
Mon, 6 Mar 2023 17:40:40 +0000 (09:40 -0800)
committerScott Gasch <[email protected]>
Mon, 6 Mar 2023 17:40:40 +0000 (09:40 -0800)
src/pyutils/files/file_utils.py
src/pyutils/persistent.py

index c98ec8ce40cef16bc76b83179822559e08176ddf..d05cae6b592110847fc73ee9d7e2499ff80c9481 100644 (file)
@@ -21,7 +21,7 @@ import pathlib
 import re
 import time
 from os.path import exists, isfile, join
-from typing import Callable, List, Literal, Optional, TextIO
+from typing import IO, Any, Callable, List, Literal, Optional
 from uuid import uuid4
 
 logger = logging.getLogger(__name__)
@@ -1205,11 +1205,12 @@ class FileWriter(contextlib.AbstractContextManager):
         self.filename = filename
         uuid = uuid4()
         self.tempfile = f"{filename}-{uuid}.tmp"
-        self.handle: Optional[TextIO] = None
+        self.handle: Optional[IO[Any]] = None
 
-    def __enter__(self) -> TextIO:
+    def __enter__(self) -> IO[Any]:
         assert not does_path_exist(self.tempfile)
         self.handle = open(self.tempfile, mode="w")
+        assert self.handle
         return self.handle
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
@@ -1225,19 +1226,20 @@ class FileWriter(contextlib.AbstractContextManager):
 class CreateFileWithMode(contextlib.AbstractContextManager):
     """This helper context manager can be used instead of the typical
     pattern for creating a file if you want to ensure that the file
-    created is a particular permission mode upon creation.
+    created is a particular filesystem permission mode upon creation.
 
     Python's open doesn't support this; you need to set the os.umask
     and then create a descriptor to open via os.open, see below.
 
         >>> import os
         >>> filename = f'/tmp/CreateFileWithModeTest.{os.getpid()}'
-        >>> with CreateFileWithMode(filename, mode=0o600) as wf:
+        >>> with CreateFileWithMode(filename, filesystem_mode=0o600) as wf:
         ...     print('This is a test', file=wf)
         >>> result = os.stat(filename)
 
-        Note: there is a high order bit set in this that is S_IFREG indicating
-        that the file is a "normal file".  Clear it with the mask.
+        Note: there is a high order bit set in this that is S_IFREG
+        indicating that the file is a "normal file".  Clear it with
+        the mask.
 
         >>> print(f'{result.st_mode & 0o7777:o}')
         600
@@ -1246,14 +1248,22 @@ class CreateFileWithMode(contextlib.AbstractContextManager):
         >>> contents
         'This is a test\\n'
         >>> remove(filename)
+
     """
 
-    def __init__(self, filename: str, mode: Optional[int] = 0o600) -> None:
+    def __init__(
+        self,
+        filename: str,
+        filesystem_mode: Optional[int] = 0o600,
+        open_mode: Optional[str] = "w",
+    ) -> None:
         """
         Args:
             filename: path of the file to create.
-            mode: the UNIX-style octal mode with which to create the
-                filename.  Defaults to 0o600.
+            filesystem_mode: the UNIX-style octal mode with which to create
+                the filename.  Defaults to 0o600.
+            open_mode: the mode to use when opening the file (e.g. 'w', 'wb',
+                etc...)
 
         .. warning::
 
@@ -1261,20 +1271,25 @@ class CreateFileWithMode(contextlib.AbstractContextManager):
 
         """
         self.filename = filename
-        if mode is not None:
-            self.mode = mode & 0o7777
+        if filesystem_mode is not None:
+            self.filesystem_mode = filesystem_mode & 0o7777
+        else:
+            self.filesystem_mode = 0o666
+        if open_mode is not None:
+            self.open_mode = open_mode
         else:
-            self.mode = 0o666
-        self.handle: Optional[TextIO] = None
+            self.open_mode = "w"
+        self.handle: Optional[IO[Any]] = None
         self.old_umask = os.umask(0)
 
-    def __enter__(self) -> TextIO:
+    def __enter__(self) -> IO[Any]:
         descriptor = os.open(
             path=self.filename,
             flags=(os.O_WRONLY | os.O_CREAT | os.O_TRUNC),
-            mode=self.mode,
+            mode=self.filesystem_mode,
         )
-        self.handle = open(descriptor, "w")
+        self.handle = open(descriptor, self.open_mode)
+        assert self.handle
         return self.handle
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> Literal[False]:
index c2d530f172c61d1e01a97faa79add88adeaf774f..200d86cf5101ce6f746b4c26ce19b2e398fd8d9c 100644 (file)
@@ -164,6 +164,11 @@ class PicklingFileBasedPersistent(FileBasedPersistent):
 
     """
 
+    @abstractmethod
+    def __init__(self, data: Optional[Any] = None):
+        """You should override this."""
+        pass
+
     @classmethod
     @overrides
     def load(cls) -> Optional[Any]:
@@ -191,7 +196,7 @@ class PicklingFileBasedPersistent(FileBasedPersistent):
             try:
                 import pickle
 
-                with open(filename, "wb") as wf:
+                with file_utils.CreateFileWithMode(filename, 0o600, "wb") as wf:
                     pickle.dump(self.get_persistent_data(), wf, pickle.HIGHEST_PROTOCOL)
                 return True
             except Exception as e:
@@ -241,6 +246,11 @@ class JsonFileBasedPersistent(FileBasedPersistent):
         c = MyClass()
     """
 
+    @abstractmethod
+    def __init__(self, data: Optional[Any]):
+        """You should override this."""
+        pass
+
     @classmethod
     @overrides
     def load(cls) -> Any: