Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / trie.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
6
7 It attempts to follow Pythonic container patterns.  See doctests
8 for examples.
9
10 """
11
12 from typing import Any, Generator, Sequence
13
14
15 class Trie(object):
16     """
17     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
18
19     It attempts to follow Pythonic container patterns.  See doctests
20     for examples.
21
22     """
23
24     def __init__(self):
25         self.root = {}
26         self.end = "~END~"
27         self.length = 0
28         self.viz = ''
29         self.content_generator: Generator[str] = None
30
31     def insert(self, item: Sequence[Any]):
32         """
33         Insert an item.
34
35         >>> t = Trie()
36         >>> t.insert('test')
37         >>> t.__contains__('test')
38         True
39
40         """
41         current = self.root
42         for child in item:
43             current = current.setdefault(child, {})
44         current[self.end] = self.end
45         self.length += 1
46
47     def __contains__(self, item: Sequence[Any]) -> bool:
48         """
49         Check whether an item is in the Trie.
50
51         >>> t = Trie()
52         >>> t.insert('test')
53         >>> t.__contains__('test')
54         True
55         >>> t.__contains__('testing')
56         False
57         >>> 'test' in t
58         True
59
60         """
61         current = self.__traverse__(item)
62         if current is None:
63             return False
64         else:
65             return self.end in current
66
67     def contains_prefix(self, item: Sequence[Any]):
68         """
69         Check whether a prefix is in the Trie.  The prefix may or may not
70         be a full item.
71
72         >>> t = Trie()
73         >>> t.insert('testicle')
74         >>> t.contains_prefix('test')
75         True
76         >>> t.contains_prefix('testicle')
77         True
78         >>> t.contains_prefix('tessel')
79         False
80
81         """
82         current = self.__traverse__(item)
83         return current is not None
84
85     def __traverse__(self, item: Sequence[Any]):
86         current = self.root
87         for child in item:
88             if child in current:
89                 current = current[child]
90             else:
91                 return None
92         return current
93
94     def __getitem__(self, item: Sequence[Any]):
95         """Given an item, return its Trie node which contains all
96         of the successor (child) node pointers.  If the item is not
97         a node in the Trie, raise a KeyError.
98
99         >>> t = Trie()
100         >>> t.insert('test')
101         >>> t.insert('testicle')
102         >>> t.insert('tessera')
103         >>> t.insert('tesack')
104         >>> t['tes']
105         {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}}
106
107         """
108         ret = self.__traverse__(item)
109         if ret is None:
110             raise KeyError(f"Node '{item}' is not in the trie")
111         return ret
112
113     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
114         if len(item) == 1:
115             del node[item]
116             if len(node) == 0 and node is not self.root:
117                 del node
118                 return True
119             else:
120                 return False
121         else:
122             car = item[0]
123             cdr = item[1:]
124             lower = node[car]
125             if self.delete_recursively(lower, cdr):
126                 return self.delete_recursively(node, car)
127             return False
128
129     def __delitem__(self, item: Sequence[Any]):
130         """
131         Delete an item from the Trie.
132
133         >>> t = Trie()
134         >>> t.insert('test')
135         >>> t.insert('tess')
136         >>> t.insert('tessel')
137         >>> len(t)
138         3
139         >>> t.root
140         {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
141         >>> t.__delitem__('test')
142         >>> len(t)
143         2
144         >>> t.root
145         {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
146         >>> for x in t:
147         ...     print(x)
148         tess
149         tessel
150         >>> t.__delitem__('tessel')
151         >>> len(t)
152         1
153         >>> t.root
154         {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}}
155         >>> for x in t:
156         ...     print(x)
157         tess
158         >>> t.__delitem__('tess')
159         >>> len(t)
160         0
161         >>> t.root
162         {}
163         >>> t.insert('testy')
164         >>> len(t)
165         1
166
167         """
168         if item not in self:
169             raise KeyError(f"Node '{item}' is not in the trie")
170         self.delete_recursively(self.root, item)
171         self.length -= 1
172
173     def __len__(self):
174         """
175         Returns a count of the Trie's item population.
176
177         >>> t = Trie()
178         >>> len(t)
179         0
180         >>> t.insert('test')
181         >>> len(t)
182         1
183         >>> t.insert('testicle')
184         >>> len(t)
185         2
186
187         """
188         return self.length
189
190     def __iter__(self):
191         self.content_generator = self.generate_recursively(self.root, '')
192         return self
193
194     def generate_recursively(self, node, path: Sequence[Any]):
195         """
196         Generate items in the trie one by one.
197
198         >>> t = Trie()
199         >>> t.insert('test')
200         >>> t.insert('tickle')
201         >>> for item in t.generate_recursively(t.root, ''):
202         ...     print(item)
203         test
204         tickle
205
206         """
207         for child in node:
208             if child == self.end:
209                 yield path
210             else:
211                 yield from self.generate_recursively(node[child], path + child)
212
213     def __next__(self):
214         """
215         Iterate through the contents of the trie.
216
217         >>> t = Trie()
218         >>> t.insert('test')
219         >>> t.insert('tickle')
220         >>> for item in t:
221         ...     print(item)
222         test
223         tickle
224
225         """
226         ret = next(self.content_generator)
227         if ret is not None:
228             return ret
229         raise StopIteration
230
231     def successors(self, item: Sequence[Any]):
232         """
233         Return a list of the successors of an item.
234
235         >>> t = Trie()
236         >>> t.insert('what')
237         >>> t.insert('who')
238         >>> t.insert('when')
239         >>> t.successors('wh')
240         ['a', 'o', 'e']
241
242         >>> u = Trie()
243         >>> u.insert(['this', 'is', 'a', 'test'])
244         >>> u.insert(['this', 'is', 'a', 'robbery'])
245         >>> u.insert(['this', 'is', 'a', 'walrus'])
246         >>> u.successors(['this', 'is', 'a'])
247         ['test', 'robbery', 'walrus']
248
249         """
250         node = self.__traverse__(item)
251         if node is None:
252             return None
253         return [x for x in node if x != self.end]
254
255     def repr_fancy(
256         self,
257         padding: str,
258         pointer: str,
259         node: Any,
260         has_sibling: bool,
261     ):
262         if node is None:
263             return ''
264         if node is not self.root:
265             ret = f'\n{padding}{pointer}'
266             if has_sibling:
267                 padding += '│  '
268             else:
269                 padding += '   '
270         else:
271             ret = f'{pointer}'
272
273         child_count = 0
274         for child in node:
275             if child != self.end:
276                 child_count += 1
277
278         for child in node:
279             if child != self.end:
280                 if child_count > 1:
281                     pointer = "├──"
282                     has_sibling = True
283                 else:
284                     pointer = "└──"
285                     has_sibling = False
286                 pointer += f'{child}'
287                 child_count -= 1
288                 ret += self.repr_fancy(padding, pointer, node[child], has_sibling)
289         return ret
290
291     def repr_brief(self, node, delimiter):
292         """
293         A friendly string representation of the contents of the Trie.
294
295         >>> t = Trie()
296         >>> t.insert([10, 0, 0, 1])
297         >>> t.insert([10, 0, 0, 2])
298         >>> t.insert([10, 10, 10, 1])
299         >>> t.insert([10, 10, 10, 2])
300         >>> t.repr_brief(t.root, '.')
301         '10.[0.0.[1,2],10.10.[1,2]]'
302
303         """
304         child_count = 0
305         my_rep = ''
306         for child in node:
307             if child != self.end:
308                 child_count += 1
309                 child_rep = self.repr_brief(node[child], delimiter)
310                 if len(child_rep) > 0:
311                     my_rep += str(child) + delimiter + child_rep + ","
312                 else:
313                     my_rep += str(child) + ","
314         if len(my_rep) > 1:
315             my_rep = my_rep[:-1]
316         if child_count > 1:
317             my_rep = f'[{my_rep}]'
318         return my_rep
319
320     def __repr__(self):
321         """
322         A friendly string representation of the contents of the Trie.  Under
323         the covers uses repr_fancy.
324
325         >>> t = Trie()
326         >>> t.insert([10, 0, 0, 1])
327         >>> t.insert([10, 0, 0, 2])
328         >>> t.insert([10, 10, 10, 1])
329         >>> t.insert([10, 10, 10, 2])
330         >>> print(t)
331         *
332         └──10
333            ├──0
334            │  └──0
335            │     ├──1
336            │     └──2
337            └──10
338               └──10
339                  ├──1
340                  └──2
341
342         """
343         return self.repr_fancy('', '*', self.root, False)
344
345
346 if __name__ == '__main__':
347     import doctest
348
349     doctest.testmod()