From: Scott Gasch Date: Fri, 24 Sep 2021 05:10:30 +0000 (-0700) Subject: Adds a trie class. X-Git-Url: https://wannabe.guru.org/gitweb/?a=commitdiff_plain;h=d82d0db238c72a4c6ab9403277d5092d3f9793d3;p=python_utils.git Adds a trie class. --- diff --git a/trie.py b/trie.py new file mode 100644 index 0000000..be7a48d --- /dev/null +++ b/trie.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 + +from typing import Any, Sequence + + +class Trie(object): + """ + This is a Trie class, see: https://en.wikipedia.org/wiki/Trie. + + It attempts to follow Pythonic container patterns. See doctests + for examples. + + """ + def __init__(self): + self.root = {} + self.end = "~#END#~" + self.length = 0 + + def insert(self, item: Sequence[Any]): + """ + Insert an item. + + >>> t = Trie() + >>> t.insert('test') + >>> t.__contains__('test') + True + + """ + current = self.root + for child in item: + current = current.setdefault(child, {}) + current[self.end] = self.end + self.length += 1 + + def __contains__(self, item: Sequence[Any]) -> bool: + """ + Check whether an item is in the Trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.__contains__('test') + True + >>> t.__contains__('testing') + False + >>> 'test' in t + True + + """ + current = self.__traverse__(item) + if current is None: + return False + else: + return self.end in current + + def contains_prefix(self, item: Sequence[Any]): + """ + Check whether a prefix is in the Trie. The prefix may or may not + be a full item. + + >>> t = Trie() + >>> t.insert('testicle') + >>> t.contains_prefix('test') + True + >>> t.contains_prefix('testicle') + True + >>> t.contains_prefix('tessel') + False + + """ + current = self.__traverse__(item) + return current is not None + + def __traverse__(self, item: Sequence[Any]): + current = self.root + for child in item: + if child in current: + current = current[child] + else: + return None + return current + + def __getitem__(self, item: Sequence[Any]): + """Given an item, return its Trie node which contains all + of the successor (child) node pointers. If the item is not + a node in the Trie, raise a KeyError. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('testicle') + >>> t.insert('tessera') + >>> t.insert('tesack') + >>> t['tes'] + {'t': {'~#END#~': '~#END#~', 'i': {'c': {'l': {'e': {'~#END#~': '~#END#~'}}}}}, 's': {'e': {'r': {'a': {'~#END#~': '~#END#~'}}}}, 'a': {'c': {'k': {'~#END#~': '~#END#~'}}}} + + """ + ret = self.__traverse__(item) + if ret is None: + raise KeyError(f"Node '{item}' is not in the trie") + return ret + + def delete_recursively(self, node, item: Sequence[Any]) -> bool: + if len(item) == 1: + del node[item] + if len(node) == 0 and node is not self.root: + del node + return True + else: + return False + else: + car = item[0] + cdr = item[1:] + lower = node[car] + if self.delete_recursively(lower, cdr): + return self.delete_recursively(node, car) + return False + + def __delitem__(self, item: Sequence[Any]): + """ + Delete an item from the Trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tess') + >>> t.insert('tessel') + >>> len(t) + 3 + >>> t.root + {'t': {'e': {'s': {'t': {'~#END#~': '~#END#~'}, 's': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}} + >>> t.__delitem__('test') + >>> len(t) + 2 + >>> t.root + {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}} + >>> for x in t: + ... print(x) + tess + tessel + >>> t.__delitem__('tessel') + >>> len(t) + 1 + >>> t.root + {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~'}}}}} + >>> for x in t: + ... print(x) + tess + >>> t.__delitem__('tess') + >>> len(t) + 0 + >>> t.root + {} + >>> t.insert('testy') + >>> len(t) + 1 + + """ + if item not in self: + raise KeyError(f"Node '{item}' is not in the trie") + self.delete_recursively(self.root, item) + self.length -= 1 + + def __len__(self): + """ + Returns a count of the Trie's item population. + + >>> t = Trie() + >>> len(t) + 0 + >>> t.insert('test') + >>> len(t) + 1 + >>> t.insert('testicle') + >>> len(t) + 2 + + """ + return self.length + + def __iter__(self): + self.content_generator = self.generate_recursively(self.root, '') + return self + + def generate_recursively(self, node, path: Sequence[Any]): + """ + Generate items in the trie one by one. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tickle') + >>> for item in t.generate_recursively(t.root, ''): + ... print(item) + test + tickle + + """ + for child in node: + if child == self.end: + yield path + else: + yield from self.generate_recursively(node[child], path + child) + + def __next__(self): + """ + Iterate through the contents of the trie. + + >>> t = Trie() + >>> t.insert('test') + >>> t.insert('tickle') + >>> for item in t: + ... print(item) + test + tickle + + """ + ret = next(self.content_generator) + if ret is not None: + return ret + raise StopIteration + + def successors(self, item: Sequence[Any]): + """ + Return a list of the successors of an item. + + >>> t = Trie() + >>> t.insert('what') + >>> t.insert('who') + >>> t.insert('when') + >>> t.successors('wh') + ['a', 'o', 'e'] + + >>> u = Trie() + >>> u.insert(['this', 'is', 'a', 'test']) + >>> u.insert(['this', 'is', 'a', 'robbery']) + >>> u.insert(['this', 'is', 'a', 'walrus']) + >>> u.successors(['this', 'is', 'a']) + ['test', 'robbery', 'walrus'] + + """ + node = self.__traverse__(item) + if node is None: + return None + return [x for x in node if x != self.end] + + +if __name__ == '__main__': + import doctest + doctest.testmod()