changes
[python_utils.git] / collect / trie.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Sequence
4
5
6 class Trie(object):
7     """
8     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
9
10     It attempts to follow Pythonic container patterns.  See doctests
11     for examples.
12
13     """
14     def __init__(self):
15         self.root = {}
16         self.end = "~END~"
17         self.length = 0
18         self.viz = ''
19
20     def insert(self, item: Sequence[Any]):
21         """
22         Insert an item.
23
24         >>> t = Trie()
25         >>> t.insert('test')
26         >>> t.__contains__('test')
27         True
28
29         """
30         current = self.root
31         for child in item:
32             current = current.setdefault(child, {})
33         current[self.end] = self.end
34         self.length += 1
35
36     def __contains__(self, item: Sequence[Any]) -> bool:
37         """
38         Check whether an item is in the Trie.
39
40         >>> t = Trie()
41         >>> t.insert('test')
42         >>> t.__contains__('test')
43         True
44         >>> t.__contains__('testing')
45         False
46         >>> 'test' in t
47         True
48
49         """
50         current = self.__traverse__(item)
51         if current is None:
52             return False
53         else:
54             return self.end in current
55
56     def contains_prefix(self, item: Sequence[Any]):
57         """
58         Check whether a prefix is in the Trie.  The prefix may or may not
59         be a full item.
60
61         >>> t = Trie()
62         >>> t.insert('testicle')
63         >>> t.contains_prefix('test')
64         True
65         >>> t.contains_prefix('testicle')
66         True
67         >>> t.contains_prefix('tessel')
68         False
69
70         """
71         current = self.__traverse__(item)
72         return current is not None
73
74     def __traverse__(self, item: Sequence[Any]):
75         current = self.root
76         for child in item:
77             if child in current:
78                 current = current[child]
79             else:
80                 return None
81         return current
82
83     def __getitem__(self, item: Sequence[Any]):
84         """Given an item, return its Trie node which contains all
85         of the successor (child) node pointers.  If the item is not
86         a node in the Trie, raise a KeyError.
87
88         >>> t = Trie()
89         >>> t.insert('test')
90         >>> t.insert('testicle')
91         >>> t.insert('tessera')
92         >>> t.insert('tesack')
93         >>> t['tes']
94         {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}}
95
96         """
97         ret = self.__traverse__(item)
98         if ret is None:
99             raise KeyError(f"Node '{item}' is not in the trie")
100         return ret
101
102     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
103         if len(item) == 1:
104             del node[item]
105             if len(node) == 0 and node is not self.root:
106                 del node
107                 return True
108             else:
109                 return False
110         else:
111             car = item[0]
112             cdr = item[1:]
113             lower = node[car]
114             if self.delete_recursively(lower, cdr):
115                 return self.delete_recursively(node, car)
116             return False
117
118     def __delitem__(self, item: Sequence[Any]):
119         """
120         Delete an item from the Trie.
121
122         >>> t = Trie()
123         >>> t.insert('test')
124         >>> t.insert('tess')
125         >>> t.insert('tessel')
126         >>> len(t)
127         3
128         >>> t.root
129         {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
130         >>> t.__delitem__('test')
131         >>> len(t)
132         2
133         >>> t.root
134         {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
135         >>> for x in t:
136         ...     print(x)
137         tess
138         tessel
139         >>> t.__delitem__('tessel')
140         >>> len(t)
141         1
142         >>> t.root
143         {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}}
144         >>> for x in t:
145         ...     print(x)
146         tess
147         >>> t.__delitem__('tess')
148         >>> len(t)
149         0
150         >>> t.root
151         {}
152         >>> t.insert('testy')
153         >>> len(t)
154         1
155
156         """
157         if item not in self:
158             raise KeyError(f"Node '{item}' is not in the trie")
159         self.delete_recursively(self.root, item)
160         self.length -= 1
161
162     def __len__(self):
163         """
164         Returns a count of the Trie's item population.
165
166         >>> t = Trie()
167         >>> len(t)
168         0
169         >>> t.insert('test')
170         >>> len(t)
171         1
172         >>> t.insert('testicle')
173         >>> len(t)
174         2
175
176         """
177         return self.length
178
179     def __iter__(self):
180         self.content_generator = self.generate_recursively(self.root, '')
181         return self
182
183     def generate_recursively(self, node, path: Sequence[Any]):
184         """
185         Generate items in the trie one by one.
186
187         >>> t = Trie()
188         >>> t.insert('test')
189         >>> t.insert('tickle')
190         >>> for item in t.generate_recursively(t.root, ''):
191         ...     print(item)
192         test
193         tickle
194
195         """
196         for child in node:
197             if child == self.end:
198                 yield path
199             else:
200                 yield from self.generate_recursively(node[child], path + child)
201
202     def __next__(self):
203         """
204         Iterate through the contents of the trie.
205
206         >>> t = Trie()
207         >>> t.insert('test')
208         >>> t.insert('tickle')
209         >>> for item in t:
210         ...     print(item)
211         test
212         tickle
213
214         """
215         ret = next(self.content_generator)
216         if ret is not None:
217             return ret
218         raise StopIteration
219
220     def successors(self, item: Sequence[Any]):
221         """
222         Return a list of the successors of an item.
223
224         >>> t = Trie()
225         >>> t.insert('what')
226         >>> t.insert('who')
227         >>> t.insert('when')
228         >>> t.successors('wh')
229         ['a', 'o', 'e']
230
231         >>> u = Trie()
232         >>> u.insert(['this', 'is', 'a', 'test'])
233         >>> u.insert(['this', 'is', 'a', 'robbery'])
234         >>> u.insert(['this', 'is', 'a', 'walrus'])
235         >>> u.successors(['this', 'is', 'a'])
236         ['test', 'robbery', 'walrus']
237
238         """
239         node = self.__traverse__(item)
240         if node is None:
241             return None
242         return [x for x in node if x != self.end]
243
244     def repr_fancy(self, padding: str, pointer: str, parent: str, node: Any, has_sibling: bool):
245         if node is None:
246             return
247         if node is not self.root:
248             ret = f'\n{padding}{pointer}'
249             if has_sibling:
250                 padding += '│  '
251             else:
252                 padding += '   '
253         else:
254             ret = f'{pointer}'
255
256         child_count = 0
257         for child in node:
258             if child != self.end:
259                 child_count += 1
260
261         for child in node:
262             if child != self.end:
263                 if child_count > 1:
264                     pointer = "├──"
265                     has_sibling = True
266                 else:
267                     pointer = "└──"
268                     has_sibling = False
269                 pointer += f'{child}'
270                 child_count -= 1
271                 ret += self.repr_fancy(padding, pointer, node, node[child], has_sibling)
272         return ret
273
274     def repr_brief(self, node, delimiter):
275         """
276         A friendly string representation of the contents of the Trie.
277
278         >>> t = Trie()
279         >>> t.insert([10, 0, 0, 1])
280         >>> t.insert([10, 0, 0, 2])
281         >>> t.insert([10, 10, 10, 1])
282         >>> t.insert([10, 10, 10, 2])
283         >>> t.repr_brief(t.root, '.')
284         '10.[0.0.[1, 2], 10.10.[1, 2]]'
285
286         """
287         child_count = 0
288         my_rep = ''
289         for child in node:
290             if child != self.end:
291                 child_count += 1
292                 child_rep = self.repr_brief(node[child], delimiter)
293                 if len(child_rep) > 0:
294                     my_rep += str(child) + delimiter + child_rep + ", "
295                 else:
296                     my_rep += str(child) + ", "
297         if len(my_rep) > 1:
298             my_rep = my_rep[:-2]
299         if child_count > 1:
300             my_rep = f'[{my_rep}]'
301         return my_rep
302
303     def __repr__(self):
304         """
305         A friendly string representation of the contents of the Trie.  Under
306         the covers uses repr_fancy.
307
308         >>> t = Trie()
309         >>> t.insert([10, 0, 0, 1])
310         >>> t.insert([10, 0, 0, 2])
311         >>> t.insert([10, 10, 10, 1])
312         >>> t.insert([10, 10, 10, 2])
313         >>> print(t)
314         *
315         └──10
316            ├──0
317            │  └──0
318            │     ├──1
319            │     └──2
320            └──10
321               └──10
322                  ├──1
323                  └──2
324
325         """
326         return self.repr_fancy('', '*', self.root, self.root, False)
327
328
329 if __name__ == '__main__':
330     import doctest
331     doctest.testmod()