Adds shuffle/scramble to list_utils.
[python_utils.git] / collect / trie.py
1 #!/usr/bin/env python3
2
3 """This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
4
5    It attempts to follow Pythonic container patterns.  See doctests
6    for examples."""
7
8 from typing import Any, Generator, Sequence
9
10
11 class Trie(object):
12     """
13     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
14
15     It attempts to follow Pythonic container patterns.  See doctests
16     for examples.
17
18     """
19
20     def __init__(self):
21         self.root = {}
22         self.end = "~END~"
23         self.length = 0
24         self.viz = ''
25         self.content_generator: Generator[str] = None
26
27     def insert(self, item: Sequence[Any]):
28         """
29         Insert an item.
30
31         >>> t = Trie()
32         >>> t.insert('test')
33         >>> t.__contains__('test')
34         True
35
36         """
37         current = self.root
38         for child in item:
39             current = current.setdefault(child, {})
40         current[self.end] = self.end
41         self.length += 1
42
43     def __contains__(self, item: Sequence[Any]) -> bool:
44         """
45         Check whether an item is in the Trie.
46
47         >>> t = Trie()
48         >>> t.insert('test')
49         >>> t.__contains__('test')
50         True
51         >>> t.__contains__('testing')
52         False
53         >>> 'test' in t
54         True
55
56         """
57         current = self.__traverse__(item)
58         if current is None:
59             return False
60         else:
61             return self.end in current
62
63     def contains_prefix(self, item: Sequence[Any]):
64         """
65         Check whether a prefix is in the Trie.  The prefix may or may not
66         be a full item.
67
68         >>> t = Trie()
69         >>> t.insert('testicle')
70         >>> t.contains_prefix('test')
71         True
72         >>> t.contains_prefix('testicle')
73         True
74         >>> t.contains_prefix('tessel')
75         False
76
77         """
78         current = self.__traverse__(item)
79         return current is not None
80
81     def __traverse__(self, item: Sequence[Any]):
82         current = self.root
83         for child in item:
84             if child in current:
85                 current = current[child]
86             else:
87                 return None
88         return current
89
90     def __getitem__(self, item: Sequence[Any]):
91         """Given an item, return its Trie node which contains all
92         of the successor (child) node pointers.  If the item is not
93         a node in the Trie, raise a KeyError.
94
95         >>> t = Trie()
96         >>> t.insert('test')
97         >>> t.insert('testicle')
98         >>> t.insert('tessera')
99         >>> t.insert('tesack')
100         >>> t['tes']
101         {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}}
102
103         """
104         ret = self.__traverse__(item)
105         if ret is None:
106             raise KeyError(f"Node '{item}' is not in the trie")
107         return ret
108
109     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
110         if len(item) == 1:
111             del node[item]
112             if len(node) == 0 and node is not self.root:
113                 del node
114                 return True
115             else:
116                 return False
117         else:
118             car = item[0]
119             cdr = item[1:]
120             lower = node[car]
121             if self.delete_recursively(lower, cdr):
122                 return self.delete_recursively(node, car)
123             return False
124
125     def __delitem__(self, item: Sequence[Any]):
126         """
127         Delete an item from the Trie.
128
129         >>> t = Trie()
130         >>> t.insert('test')
131         >>> t.insert('tess')
132         >>> t.insert('tessel')
133         >>> len(t)
134         3
135         >>> t.root
136         {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
137         >>> t.__delitem__('test')
138         >>> len(t)
139         2
140         >>> t.root
141         {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
142         >>> for x in t:
143         ...     print(x)
144         tess
145         tessel
146         >>> t.__delitem__('tessel')
147         >>> len(t)
148         1
149         >>> t.root
150         {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}}
151         >>> for x in t:
152         ...     print(x)
153         tess
154         >>> t.__delitem__('tess')
155         >>> len(t)
156         0
157         >>> t.root
158         {}
159         >>> t.insert('testy')
160         >>> len(t)
161         1
162
163         """
164         if item not in self:
165             raise KeyError(f"Node '{item}' is not in the trie")
166         self.delete_recursively(self.root, item)
167         self.length -= 1
168
169     def __len__(self):
170         """
171         Returns a count of the Trie's item population.
172
173         >>> t = Trie()
174         >>> len(t)
175         0
176         >>> t.insert('test')
177         >>> len(t)
178         1
179         >>> t.insert('testicle')
180         >>> len(t)
181         2
182
183         """
184         return self.length
185
186     def __iter__(self):
187         self.content_generator = self.generate_recursively(self.root, '')
188         return self
189
190     def generate_recursively(self, node, path: Sequence[Any]):
191         """
192         Generate items in the trie one by one.
193
194         >>> t = Trie()
195         >>> t.insert('test')
196         >>> t.insert('tickle')
197         >>> for item in t.generate_recursively(t.root, ''):
198         ...     print(item)
199         test
200         tickle
201
202         """
203         for child in node:
204             if child == self.end:
205                 yield path
206             else:
207                 yield from self.generate_recursively(node[child], path + child)
208
209     def __next__(self):
210         """
211         Iterate through the contents of the trie.
212
213         >>> t = Trie()
214         >>> t.insert('test')
215         >>> t.insert('tickle')
216         >>> for item in t:
217         ...     print(item)
218         test
219         tickle
220
221         """
222         ret = next(self.content_generator)
223         if ret is not None:
224             return ret
225         raise StopIteration
226
227     def successors(self, item: Sequence[Any]):
228         """
229         Return a list of the successors of an item.
230
231         >>> t = Trie()
232         >>> t.insert('what')
233         >>> t.insert('who')
234         >>> t.insert('when')
235         >>> t.successors('wh')
236         ['a', 'o', 'e']
237
238         >>> u = Trie()
239         >>> u.insert(['this', 'is', 'a', 'test'])
240         >>> u.insert(['this', 'is', 'a', 'robbery'])
241         >>> u.insert(['this', 'is', 'a', 'walrus'])
242         >>> u.successors(['this', 'is', 'a'])
243         ['test', 'robbery', 'walrus']
244
245         """
246         node = self.__traverse__(item)
247         if node is None:
248             return None
249         return [x for x in node if x != self.end]
250
251     def repr_fancy(
252         self,
253         padding: str,
254         pointer: str,
255         node: Any,
256         has_sibling: bool,
257     ):
258         if node is None:
259             return ''
260         if node is not self.root:
261             ret = f'\n{padding}{pointer}'
262             if has_sibling:
263                 padding += '│  '
264             else:
265                 padding += '   '
266         else:
267             ret = f'{pointer}'
268
269         child_count = 0
270         for child in node:
271             if child != self.end:
272                 child_count += 1
273
274         for child in node:
275             if child != self.end:
276                 if child_count > 1:
277                     pointer = "├──"
278                     has_sibling = True
279                 else:
280                     pointer = "└──"
281                     has_sibling = False
282                 pointer += f'{child}'
283                 child_count -= 1
284                 ret += self.repr_fancy(padding, pointer, node[child], has_sibling)
285         return ret
286
287     def repr_brief(self, node, delimiter):
288         """
289         A friendly string representation of the contents of the Trie.
290
291         >>> t = Trie()
292         >>> t.insert([10, 0, 0, 1])
293         >>> t.insert([10, 0, 0, 2])
294         >>> t.insert([10, 10, 10, 1])
295         >>> t.insert([10, 10, 10, 2])
296         >>> t.repr_brief(t.root, '.')
297         '10.[0.0.[1,2],10.10.[1,2]]'
298
299         """
300         child_count = 0
301         my_rep = ''
302         for child in node:
303             if child != self.end:
304                 child_count += 1
305                 child_rep = self.repr_brief(node[child], delimiter)
306                 if len(child_rep) > 0:
307                     my_rep += str(child) + delimiter + child_rep + ","
308                 else:
309                     my_rep += str(child) + ","
310         if len(my_rep) > 1:
311             my_rep = my_rep[:-1]
312         if child_count > 1:
313             my_rep = f'[{my_rep}]'
314         return my_rep
315
316     def __repr__(self):
317         """
318         A friendly string representation of the contents of the Trie.  Under
319         the covers uses repr_fancy.
320
321         >>> t = Trie()
322         >>> t.insert([10, 0, 0, 1])
323         >>> t.insert([10, 0, 0, 2])
324         >>> t.insert([10, 10, 10, 1])
325         >>> t.insert([10, 10, 10, 2])
326         >>> print(t)
327         *
328         └──10
329            ├──0
330            │  └──0
331            │     ├──1
332            │     └──2
333            └──10
334               └──10
335                  ├──1
336                  └──2
337
338         """
339         return self.repr_fancy('', '*', self.root, False)
340
341
342 if __name__ == '__main__':
343     import doctest
344
345     doctest.testmod()