Adds a trie class.
authorScott Gasch <[email protected]>
Fri, 24 Sep 2021 05:10:30 +0000 (22:10 -0700)
committerScott Gasch <[email protected]>
Fri, 24 Sep 2021 05:10:30 +0000 (22:10 -0700)
trie.py [new file with mode: 0644]

diff --git a/trie.py b/trie.py
new file mode 100644 (file)
index 0000000..be7a48d
--- /dev/null
+++ b/trie.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+
+from typing import Any, Sequence
+
+
+class Trie(object):
+    """
+    This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
+
+    It attempts to follow Pythonic container patterns.  See doctests
+    for examples.
+
+    """
+    def __init__(self):
+        self.root = {}
+        self.end = "~#END#~"
+        self.length = 0
+
+    def insert(self, item: Sequence[Any]):
+        """
+        Insert an item.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.__contains__('test')
+        True
+
+        """
+        current = self.root
+        for child in item:
+            current = current.setdefault(child, {})
+        current[self.end] = self.end
+        self.length += 1
+
+    def __contains__(self, item: Sequence[Any]) -> bool:
+        """
+        Check whether an item is in the Trie.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.__contains__('test')
+        True
+        >>> t.__contains__('testing')
+        False
+        >>> 'test' in t
+        True
+
+        """
+        current = self.__traverse__(item)
+        if current is None:
+            return False
+        else:
+            return self.end in current
+
+    def contains_prefix(self, item: Sequence[Any]):
+        """
+        Check whether a prefix is in the Trie.  The prefix may or may not
+        be a full item.
+
+        >>> t = Trie()
+        >>> t.insert('testicle')
+        >>> t.contains_prefix('test')
+        True
+        >>> t.contains_prefix('testicle')
+        True
+        >>> t.contains_prefix('tessel')
+        False
+
+        """
+        current = self.__traverse__(item)
+        return current is not None
+
+    def __traverse__(self, item: Sequence[Any]):
+        current = self.root
+        for child in item:
+            if child in current:
+                current = current[child]
+            else:
+                return None
+        return current
+
+    def __getitem__(self, item: Sequence[Any]):
+        """Given an item, return its Trie node which contains all
+        of the successor (child) node pointers.  If the item is not
+        a node in the Trie, raise a KeyError.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.insert('testicle')
+        >>> t.insert('tessera')
+        >>> t.insert('tesack')
+        >>> t['tes']
+        {'t': {'~#END#~': '~#END#~', 'i': {'c': {'l': {'e': {'~#END#~': '~#END#~'}}}}}, 's': {'e': {'r': {'a': {'~#END#~': '~#END#~'}}}}, 'a': {'c': {'k': {'~#END#~': '~#END#~'}}}}
+
+        """
+        ret = self.__traverse__(item)
+        if ret is None:
+            raise KeyError(f"Node '{item}' is not in the trie")
+        return ret
+
+    def delete_recursively(self, node, item: Sequence[Any]) -> bool:
+        if len(item) == 1:
+            del node[item]
+            if len(node) == 0 and node is not self.root:
+                del node
+                return True
+            else:
+                return False
+        else:
+            car = item[0]
+            cdr = item[1:]
+            lower = node[car]
+            if self.delete_recursively(lower, cdr):
+                return self.delete_recursively(node, car)
+            return False
+
+    def __delitem__(self, item: Sequence[Any]):
+        """
+        Delete an item from the Trie.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.insert('tess')
+        >>> t.insert('tessel')
+        >>> len(t)
+        3
+        >>> t.root
+        {'t': {'e': {'s': {'t': {'~#END#~': '~#END#~'}, 's': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}}
+        >>> t.__delitem__('test')
+        >>> len(t)
+        2
+        >>> t.root
+        {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}}
+        >>> for x in t:
+        ...     print(x)
+        tess
+        tessel
+        >>> t.__delitem__('tessel')
+        >>> len(t)
+        1
+        >>> t.root
+        {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~'}}}}}
+        >>> for x in t:
+        ...     print(x)
+        tess
+        >>> t.__delitem__('tess')
+        >>> len(t)
+        0
+        >>> t.root
+        {}
+        >>> t.insert('testy')
+        >>> len(t)
+        1
+
+        """
+        if item not in self:
+            raise KeyError(f"Node '{item}' is not in the trie")
+        self.delete_recursively(self.root, item)
+        self.length -= 1
+
+    def __len__(self):
+        """
+        Returns a count of the Trie's item population.
+
+        >>> t = Trie()
+        >>> len(t)
+        0
+        >>> t.insert('test')
+        >>> len(t)
+        1
+        >>> t.insert('testicle')
+        >>> len(t)
+        2
+
+        """
+        return self.length
+
+    def __iter__(self):
+        self.content_generator = self.generate_recursively(self.root, '')
+        return self
+
+    def generate_recursively(self, node, path: Sequence[Any]):
+        """
+        Generate items in the trie one by one.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.insert('tickle')
+        >>> for item in t.generate_recursively(t.root, ''):
+        ...     print(item)
+        test
+        tickle
+
+        """
+        for child in node:
+            if child == self.end:
+                yield path
+            else:
+                yield from self.generate_recursively(node[child], path + child)
+
+    def __next__(self):
+        """
+        Iterate through the contents of the trie.
+
+        >>> t = Trie()
+        >>> t.insert('test')
+        >>> t.insert('tickle')
+        >>> for item in t:
+        ...     print(item)
+        test
+        tickle
+
+        """
+        ret = next(self.content_generator)
+        if ret is not None:
+            return ret
+        raise StopIteration
+
+    def successors(self, item: Sequence[Any]):
+        """
+        Return a list of the successors of an item.
+
+        >>> t = Trie()
+        >>> t.insert('what')
+        >>> t.insert('who')
+        >>> t.insert('when')
+        >>> t.successors('wh')
+        ['a', 'o', 'e']
+
+        >>> u = Trie()
+        >>> u.insert(['this', 'is', 'a', 'test'])
+        >>> u.insert(['this', 'is', 'a', 'robbery'])
+        >>> u.insert(['this', 'is', 'a', 'walrus'])
+        >>> u.successors(['this', 'is', 'a'])
+        ['test', 'robbery', 'walrus']
+
+        """
+        node = self.__traverse__(item)
+        if node is None:
+            return None
+        return [x for x in node if x != self.end]
+
+
+if __name__ == '__main__':
+    import doctest
+    doctest.testmod()