More cleanup, yey!
[python_utils.git] / unscrambler.py
index 78c1f9b4f7e7c1a5ef618b54eeed2eb012db7a4d..1b242309b649eaa036277fdb22fc6f9c7705f0c8 100644 (file)
@@ -1,15 +1,20 @@
 #!/usr/bin/env python3
 
+"""A fast word unscrambler library."""
+
 import logging
-from typing import Dict, Mapping
+from typing import Dict, Mapping, Optional
 
 import config
 import decorator_utils
+import file_utils
 import list_utils
 
-cfg = config.add_commandline_args(f'Unscramble! ({__file__})', 'A fast word unscrambler.')
+cfg = config.add_commandline_args(
+    f'Unscrambler base library ({__file__})', 'A fast word unscrambler.'
+)
 cfg.add_argument(
-    "--unscramble_indexfile",
+    "--unscrambler_default_indexfile",
     help="Path to a file of signature -> word index.",
     metavar="FILENAME",
     default="/usr/share/dict/sparse_index",
@@ -18,10 +23,10 @@ cfg.add_argument(
 logger = logging.getLogger(__name__)
 
 letters_bits = 32
-letters_mask = 2 ** letters_bits - 1
+letters_mask = 2**letters_bits - 1
 
 fprint_bits = 52
-fprint_mask = (2 ** fprint_bits - 1) << letters_bits
+fprint_mask = (2**fprint_bits - 1) << letters_bits
 
 fprint_feature_bit = {
     'e': 0,
@@ -98,35 +103,40 @@ class Unscrambler(object):
 
     """
 
-    def __init__(self):
+    def __init__(self, indexfile: Optional[str] = None):
         # Cached index per instance.
         self.sigs = []
         self.words = []
 
-        if 'unscramble_indexfile' in config.config:
-            indexfile = config.config['unscramble_indexfile']
-        else:
-            indexfile = "/usr/share/dict/sparse_index"
-
-        with open(indexfile, 'r') as rf:
+        filename = Unscrambler.get_indexfile(indexfile)
+        with open(filename, 'r') as rf:
             lines = rf.readlines()
         for line in lines:
             line = line[:-1]
             (fsig, word) = line.split('+')
-            fsig = int(fsig, 16)
-            self.sigs.append(fsig)
+            isig = int(fsig, 16)
+            self.sigs.append(isig)
             self.words.append(word)
 
+    @staticmethod
+    def get_indexfile(indexfile: Optional[str]) -> str:
+        if indexfile is None:
+            if 'unscrambler_default_indexfile' in config.config:
+                indexfile = config.config['unscramble_indexfile']
+            else:
+                indexfile = "/usr/share/dict/sparse_index"
+        else:
+            assert file_utils.file_is_readable(indexfile), f"Can't read {indexfile}"
+        return indexfile
+
     # 52 bits
     @staticmethod
-    def _compute_word_fingerprint(word: str, population: Mapping[str, int]) -> int:
+    def _compute_word_fingerprint(population: Mapping[str, int]) -> int:
         fp = 0
         for pair in sorted(population.items(), key=lambda x: x[1], reverse=True):
             letter = pair[0]
             if letter in fprint_feature_bit:
-                count = pair[1]
-                if count > 3:
-                    count = 3
+                count = min(pair[1], 3)
                 shift = fprint_feature_bit[letter]
                 s = count << shift
                 fp |= s
@@ -135,25 +145,23 @@ class Unscrambler(object):
     # 32 bits
     @staticmethod
     def _compute_word_letter_sig(
-        letter_sigs: Mapping[str, int],
+        lsigs: Mapping[str, int],
         word: str,
         population: Mapping[str, int],
     ) -> int:
         sig = 0
         for pair in sorted(population.items(), key=lambda x: x[1], reverse=True):
             letter = pair[0]
-            if letter not in letter_sigs:
+            if letter not in lsigs:
                 continue
-            s = letter_sigs[letter]
+            s = lsigs[letter]
             count = pair[1]
             if count > 1:
                 s <<= count
                 s |= count
             s &= letters_mask
             sig ^= s
-        length = len(word)
-        if length > 31:
-            length = 31
+        length = min(len(word), 31)
         sig ^= length << 8
         sig &= letters_mask
         return sig
@@ -180,7 +188,7 @@ class Unscrambler(object):
 
         """
         population = list_utils.population_counts(word)
-        fprint = Unscrambler._compute_word_fingerprint(word, population)
+        fprint = Unscrambler._compute_word_fingerprint(population)
         letter_sig = Unscrambler._compute_word_letter_sig(letter_sigs, word, population)
         assert fprint & letter_sig == 0
         sig = fprint | letter_sig
@@ -188,7 +196,6 @@ class Unscrambler(object):
 
     @staticmethod
     def repopulate(
-        letter_sigs: Dict[str, int],
         dictfile: str = '/usr/share/dict/words',
         indexfile: str = '/usr/share/dict/sparse_index',
     ) -> None:
@@ -202,13 +209,13 @@ class Unscrambler(object):
             for word in f:
                 word = word.replace('\n', '')
                 word = word.lower()
-                sig = Unscrambler.compute_word_sig(letter_sigs, word)
-                logger.debug("%s => 0x%x" % (word, sig))
+                sig = Unscrambler.compute_word_sig(word)
+                logger.debug("%s => 0x%x", word, sig)
                 if word in seen:
                     continue
                 seen.add(word)
                 if sig in words_by_sigs:
-                    words_by_sigs[sig] += ",%s" % word
+                    words_by_sigs[sig] += f",{word}"
                 else:
                     words_by_sigs[sig] = word
         with open(indexfile, 'w') as f:
@@ -243,13 +250,11 @@ class Unscrambler(object):
 
         """
         ret = {}
-        (exact, location) = list_utils.binary_search(self.sigs, sig)
+        (_, location) = list_utils.binary_search(self.sigs, sig)
         start = location - window_size
-        if start < 0:
-            start = 0
+        start = max(start, 0)
         end = location + 1 + window_size
-        if end > len(self.words):
-            end = len(self.words)
+        end = min(end, len(self.words))
 
         for x in range(start, end):
             word = self.words[x]