Change settings in flake8 and black.
[python_utils.git] / collect / trie.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Sequence
4
5
6 class Trie(object):
7     """
8     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
9
10     It attempts to follow Pythonic container patterns.  See doctests
11     for examples.
12
13     """
14
15     def __init__(self):
16         self.root = {}
17         self.end = "~END~"
18         self.length = 0
19         self.viz = ''
20
21     def insert(self, item: Sequence[Any]):
22         """
23         Insert an item.
24
25         >>> t = Trie()
26         >>> t.insert('test')
27         >>> t.__contains__('test')
28         True
29
30         """
31         current = self.root
32         for child in item:
33             current = current.setdefault(child, {})
34         current[self.end] = self.end
35         self.length += 1
36
37     def __contains__(self, item: Sequence[Any]) -> bool:
38         """
39         Check whether an item is in the Trie.
40
41         >>> t = Trie()
42         >>> t.insert('test')
43         >>> t.__contains__('test')
44         True
45         >>> t.__contains__('testing')
46         False
47         >>> 'test' in t
48         True
49
50         """
51         current = self.__traverse__(item)
52         if current is None:
53             return False
54         else:
55             return self.end in current
56
57     def contains_prefix(self, item: Sequence[Any]):
58         """
59         Check whether a prefix is in the Trie.  The prefix may or may not
60         be a full item.
61
62         >>> t = Trie()
63         >>> t.insert('testicle')
64         >>> t.contains_prefix('test')
65         True
66         >>> t.contains_prefix('testicle')
67         True
68         >>> t.contains_prefix('tessel')
69         False
70
71         """
72         current = self.__traverse__(item)
73         return current is not None
74
75     def __traverse__(self, item: Sequence[Any]):
76         current = self.root
77         for child in item:
78             if child in current:
79                 current = current[child]
80             else:
81                 return None
82         return current
83
84     def __getitem__(self, item: Sequence[Any]):
85         """Given an item, return its Trie node which contains all
86         of the successor (child) node pointers.  If the item is not
87         a node in the Trie, raise a KeyError.
88
89         >>> t = Trie()
90         >>> t.insert('test')
91         >>> t.insert('testicle')
92         >>> t.insert('tessera')
93         >>> t.insert('tesack')
94         >>> t['tes']
95         {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}}
96
97         """
98         ret = self.__traverse__(item)
99         if ret is None:
100             raise KeyError(f"Node '{item}' is not in the trie")
101         return ret
102
103     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
104         if len(item) == 1:
105             del node[item]
106             if len(node) == 0 and node is not self.root:
107                 del node
108                 return True
109             else:
110                 return False
111         else:
112             car = item[0]
113             cdr = item[1:]
114             lower = node[car]
115             if self.delete_recursively(lower, cdr):
116                 return self.delete_recursively(node, car)
117             return False
118
119     def __delitem__(self, item: Sequence[Any]):
120         """
121         Delete an item from the Trie.
122
123         >>> t = Trie()
124         >>> t.insert('test')
125         >>> t.insert('tess')
126         >>> t.insert('tessel')
127         >>> len(t)
128         3
129         >>> t.root
130         {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
131         >>> t.__delitem__('test')
132         >>> len(t)
133         2
134         >>> t.root
135         {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
136         >>> for x in t:
137         ...     print(x)
138         tess
139         tessel
140         >>> t.__delitem__('tessel')
141         >>> len(t)
142         1
143         >>> t.root
144         {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}}
145         >>> for x in t:
146         ...     print(x)
147         tess
148         >>> t.__delitem__('tess')
149         >>> len(t)
150         0
151         >>> t.root
152         {}
153         >>> t.insert('testy')
154         >>> len(t)
155         1
156
157         """
158         if item not in self:
159             raise KeyError(f"Node '{item}' is not in the trie")
160         self.delete_recursively(self.root, item)
161         self.length -= 1
162
163     def __len__(self):
164         """
165         Returns a count of the Trie's item population.
166
167         >>> t = Trie()
168         >>> len(t)
169         0
170         >>> t.insert('test')
171         >>> len(t)
172         1
173         >>> t.insert('testicle')
174         >>> len(t)
175         2
176
177         """
178         return self.length
179
180     def __iter__(self):
181         self.content_generator = self.generate_recursively(self.root, '')
182         return self
183
184     def generate_recursively(self, node, path: Sequence[Any]):
185         """
186         Generate items in the trie one by one.
187
188         >>> t = Trie()
189         >>> t.insert('test')
190         >>> t.insert('tickle')
191         >>> for item in t.generate_recursively(t.root, ''):
192         ...     print(item)
193         test
194         tickle
195
196         """
197         for child in node:
198             if child == self.end:
199                 yield path
200             else:
201                 yield from self.generate_recursively(node[child], path + child)
202
203     def __next__(self):
204         """
205         Iterate through the contents of the trie.
206
207         >>> t = Trie()
208         >>> t.insert('test')
209         >>> t.insert('tickle')
210         >>> for item in t:
211         ...     print(item)
212         test
213         tickle
214
215         """
216         ret = next(self.content_generator)
217         if ret is not None:
218             return ret
219         raise StopIteration
220
221     def successors(self, item: Sequence[Any]):
222         """
223         Return a list of the successors of an item.
224
225         >>> t = Trie()
226         >>> t.insert('what')
227         >>> t.insert('who')
228         >>> t.insert('when')
229         >>> t.successors('wh')
230         ['a', 'o', 'e']
231
232         >>> u = Trie()
233         >>> u.insert(['this', 'is', 'a', 'test'])
234         >>> u.insert(['this', 'is', 'a', 'robbery'])
235         >>> u.insert(['this', 'is', 'a', 'walrus'])
236         >>> u.successors(['this', 'is', 'a'])
237         ['test', 'robbery', 'walrus']
238
239         """
240         node = self.__traverse__(item)
241         if node is None:
242             return None
243         return [x for x in node if x != self.end]
244
245     def repr_fancy(
246         self,
247         padding: str,
248         pointer: str,
249         parent: str,
250         node: Any,
251         has_sibling: bool,
252     ):
253         if node is None:
254             return
255         if node is not self.root:
256             ret = f'\n{padding}{pointer}'
257             if has_sibling:
258                 padding += '│  '
259             else:
260                 padding += '   '
261         else:
262             ret = f'{pointer}'
263
264         child_count = 0
265         for child in node:
266             if child != self.end:
267                 child_count += 1
268
269         for child in node:
270             if child != self.end:
271                 if child_count > 1:
272                     pointer = "├──"
273                     has_sibling = True
274                 else:
275                     pointer = "└──"
276                     has_sibling = False
277                 pointer += f'{child}'
278                 child_count -= 1
279                 ret += self.repr_fancy(padding, pointer, node, node[child], has_sibling)
280         return ret
281
282     def repr_brief(self, node, delimiter):
283         """
284         A friendly string representation of the contents of the Trie.
285
286         >>> t = Trie()
287         >>> t.insert([10, 0, 0, 1])
288         >>> t.insert([10, 0, 0, 2])
289         >>> t.insert([10, 10, 10, 1])
290         >>> t.insert([10, 10, 10, 2])
291         >>> t.repr_brief(t.root, '.')
292         '10.[0.0.[1, 2], 10.10.[1, 2]]'
293
294         """
295         child_count = 0
296         my_rep = ''
297         for child in node:
298             if child != self.end:
299                 child_count += 1
300                 child_rep = self.repr_brief(node[child], delimiter)
301                 if len(child_rep) > 0:
302                     my_rep += str(child) + delimiter + child_rep + ", "
303                 else:
304                     my_rep += str(child) + ", "
305         if len(my_rep) > 1:
306             my_rep = my_rep[:-2]
307         if child_count > 1:
308             my_rep = f'[{my_rep}]'
309         return my_rep
310
311     def __repr__(self):
312         """
313         A friendly string representation of the contents of the Trie.  Under
314         the covers uses repr_fancy.
315
316         >>> t = Trie()
317         >>> t.insert([10, 0, 0, 1])
318         >>> t.insert([10, 0, 0, 2])
319         >>> t.insert([10, 10, 10, 1])
320         >>> t.insert([10, 10, 10, 2])
321         >>> print(t)
322         *
323         └──10
324            ├──0
325            │  └──0
326            │     ├──1
327            │     └──2
328            └──10
329               └──10
330                  ├──1
331                  └──2
332
333         """
334         return self.repr_fancy('', '*', self.root, self.root, False)
335
336
337 if __name__ == '__main__':
338     import doctest
339
340     doctest.testmod()