be7a48d739b78a9bb47257b0c6670178ce0539fd
[python_utils.git] / trie.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Sequence
4
5
6 class Trie(object):
7     """
8     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
9
10     It attempts to follow Pythonic container patterns.  See doctests
11     for examples.
12
13     """
14     def __init__(self):
15         self.root = {}
16         self.end = "~#END#~"
17         self.length = 0
18
19     def insert(self, item: Sequence[Any]):
20         """
21         Insert an item.
22
23         >>> t = Trie()
24         >>> t.insert('test')
25         >>> t.__contains__('test')
26         True
27
28         """
29         current = self.root
30         for child in item:
31             current = current.setdefault(child, {})
32         current[self.end] = self.end
33         self.length += 1
34
35     def __contains__(self, item: Sequence[Any]) -> bool:
36         """
37         Check whether an item is in the Trie.
38
39         >>> t = Trie()
40         >>> t.insert('test')
41         >>> t.__contains__('test')
42         True
43         >>> t.__contains__('testing')
44         False
45         >>> 'test' in t
46         True
47
48         """
49         current = self.__traverse__(item)
50         if current is None:
51             return False
52         else:
53             return self.end in current
54
55     def contains_prefix(self, item: Sequence[Any]):
56         """
57         Check whether a prefix is in the Trie.  The prefix may or may not
58         be a full item.
59
60         >>> t = Trie()
61         >>> t.insert('testicle')
62         >>> t.contains_prefix('test')
63         True
64         >>> t.contains_prefix('testicle')
65         True
66         >>> t.contains_prefix('tessel')
67         False
68
69         """
70         current = self.__traverse__(item)
71         return current is not None
72
73     def __traverse__(self, item: Sequence[Any]):
74         current = self.root
75         for child in item:
76             if child in current:
77                 current = current[child]
78             else:
79                 return None
80         return current
81
82     def __getitem__(self, item: Sequence[Any]):
83         """Given an item, return its Trie node which contains all
84         of the successor (child) node pointers.  If the item is not
85         a node in the Trie, raise a KeyError.
86
87         >>> t = Trie()
88         >>> t.insert('test')
89         >>> t.insert('testicle')
90         >>> t.insert('tessera')
91         >>> t.insert('tesack')
92         >>> t['tes']
93         {'t': {'~#END#~': '~#END#~', 'i': {'c': {'l': {'e': {'~#END#~': '~#END#~'}}}}}, 's': {'e': {'r': {'a': {'~#END#~': '~#END#~'}}}}, 'a': {'c': {'k': {'~#END#~': '~#END#~'}}}}
94
95         """
96         ret = self.__traverse__(item)
97         if ret is None:
98             raise KeyError(f"Node '{item}' is not in the trie")
99         return ret
100
101     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
102         if len(item) == 1:
103             del node[item]
104             if len(node) == 0 and node is not self.root:
105                 del node
106                 return True
107             else:
108                 return False
109         else:
110             car = item[0]
111             cdr = item[1:]
112             lower = node[car]
113             if self.delete_recursively(lower, cdr):
114                 return self.delete_recursively(node, car)
115             return False
116
117     def __delitem__(self, item: Sequence[Any]):
118         """
119         Delete an item from the Trie.
120
121         >>> t = Trie()
122         >>> t.insert('test')
123         >>> t.insert('tess')
124         >>> t.insert('tessel')
125         >>> len(t)
126         3
127         >>> t.root
128         {'t': {'e': {'s': {'t': {'~#END#~': '~#END#~'}, 's': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}}
129         >>> t.__delitem__('test')
130         >>> len(t)
131         2
132         >>> t.root
133         {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~', 'e': {'l': {'~#END#~': '~#END#~'}}}}}}}
134         >>> for x in t:
135         ...     print(x)
136         tess
137         tessel
138         >>> t.__delitem__('tessel')
139         >>> len(t)
140         1
141         >>> t.root
142         {'t': {'e': {'s': {'s': {'~#END#~': '~#END#~'}}}}}
143         >>> for x in t:
144         ...     print(x)
145         tess
146         >>> t.__delitem__('tess')
147         >>> len(t)
148         0
149         >>> t.root
150         {}
151         >>> t.insert('testy')
152         >>> len(t)
153         1
154
155         """
156         if item not in self:
157             raise KeyError(f"Node '{item}' is not in the trie")
158         self.delete_recursively(self.root, item)
159         self.length -= 1
160
161     def __len__(self):
162         """
163         Returns a count of the Trie's item population.
164
165         >>> t = Trie()
166         >>> len(t)
167         0
168         >>> t.insert('test')
169         >>> len(t)
170         1
171         >>> t.insert('testicle')
172         >>> len(t)
173         2
174
175         """
176         return self.length
177
178     def __iter__(self):
179         self.content_generator = self.generate_recursively(self.root, '')
180         return self
181
182     def generate_recursively(self, node, path: Sequence[Any]):
183         """
184         Generate items in the trie one by one.
185
186         >>> t = Trie()
187         >>> t.insert('test')
188         >>> t.insert('tickle')
189         >>> for item in t.generate_recursively(t.root, ''):
190         ...     print(item)
191         test
192         tickle
193
194         """
195         for child in node:
196             if child == self.end:
197                 yield path
198             else:
199                 yield from self.generate_recursively(node[child], path + child)
200
201     def __next__(self):
202         """
203         Iterate through the contents of the trie.
204
205         >>> t = Trie()
206         >>> t.insert('test')
207         >>> t.insert('tickle')
208         >>> for item in t:
209         ...     print(item)
210         test
211         tickle
212
213         """
214         ret = next(self.content_generator)
215         if ret is not None:
216             return ret
217         raise StopIteration
218
219     def successors(self, item: Sequence[Any]):
220         """
221         Return a list of the successors of an item.
222
223         >>> t = Trie()
224         >>> t.insert('what')
225         >>> t.insert('who')
226         >>> t.insert('when')
227         >>> t.successors('wh')
228         ['a', 'o', 'e']
229
230         >>> u = Trie()
231         >>> u.insert(['this', 'is', 'a', 'test'])
232         >>> u.insert(['this', 'is', 'a', 'robbery'])
233         >>> u.insert(['this', 'is', 'a', 'walrus'])
234         >>> u.successors(['this', 'is', 'a'])
235         ['test', 'robbery', 'walrus']
236
237         """
238         node = self.__traverse__(item)
239         if node is None:
240             return None
241         return [x for x in node if x != self.end]
242
243
244 if __name__ == '__main__':
245     import doctest
246     doctest.testmod()