762ae3a992fcc860f69a1e2897d73ea1af17c4cb
[pyutils.git] / src / pyutils / collectionz / trie.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
6
7 It attempts to follow Pythonic container patterns.  See doctests
8 for examples.
9
10 """
11
12 import logging
13 from typing import Any, Generator, Sequence
14
15 logger = logging.getLogger(__name__)
16
17
18 class Trie(object):
19     """
20     This is a Trie class, see: https://en.wikipedia.org/wiki/Trie.
21
22     It attempts to follow Pythonic container patterns.  See doctests
23     for examples.
24
25     """
26
27     def __init__(self):
28         self.root = {}
29         self.end = "~END~"
30         self.length = 0
31         self.viz = ''
32         self.content_generator: Generator[str] = None
33
34     def insert(self, item: Sequence[Any]):
35         """
36         Insert an item.
37
38         >>> t = Trie()
39         >>> t.insert('test')
40         >>> t.__contains__('test')
41         True
42
43         """
44         current = self.root
45         for child in item:
46             current = current.setdefault(child, {})
47         current[self.end] = self.end
48         self.length += 1
49
50     def __contains__(self, item: Sequence[Any]) -> bool:
51         """
52         Check whether an item is in the Trie.
53
54         >>> t = Trie()
55         >>> t.insert('test')
56         >>> t.__contains__('test')
57         True
58         >>> t.__contains__('testing')
59         False
60         >>> 'test' in t
61         True
62
63         """
64         current = self.__traverse__(item)
65         if current is None:
66             return False
67         else:
68             return self.end in current
69
70     def contains_prefix(self, item: Sequence[Any]):
71         """
72         Check whether a prefix is in the Trie.  The prefix may or may not
73         be a full item.
74
75         >>> t = Trie()
76         >>> t.insert('tricycle')
77         >>> t.contains_prefix('tri')
78         True
79         >>> t.contains_prefix('tricycle')
80         True
81         >>> t.contains_prefix('triad')
82         False
83
84         """
85         current = self.__traverse__(item)
86         return current is not None
87
88     def __traverse__(self, item: Sequence[Any]):
89         current = self.root
90         for child in item:
91             if child in current:
92                 current = current[child]
93             else:
94                 return None
95         return current
96
97     def __getitem__(self, item: Sequence[Any]):
98         """Given an item, return its Trie node which contains all
99         of the successor (child) node pointers.  If the item is not
100         a node in the Trie, raise a KeyError.
101
102         >>> t = Trie()
103         >>> t.insert('test')
104         >>> t.insert('testicle')
105         >>> t.insert('tessera')
106         >>> t.insert('tesack')
107         >>> t['tes']
108         {'t': {'~END~': '~END~', 'i': {'c': {'l': {'e': {'~END~': '~END~'}}}}}, 's': {'e': {'r': {'a': {'~END~': '~END~'}}}}, 'a': {'c': {'k': {'~END~': '~END~'}}}}
109
110         """
111         ret = self.__traverse__(item)
112         if ret is None:
113             raise KeyError(f"Node '{item}' is not in the trie")
114         return ret
115
116     def delete_recursively(self, node, item: Sequence[Any]) -> bool:
117         if len(item) == 1:
118             del node[item]
119             if len(node) == 0 and node is not self.root:
120                 del node
121                 return True
122             else:
123                 return False
124         else:
125             car = item[0]
126             cdr = item[1:]
127             lower = node[car]
128             if self.delete_recursively(lower, cdr):
129                 return self.delete_recursively(node, car)
130             return False
131
132     def __delitem__(self, item: Sequence[Any]):
133         """
134         Delete an item from the Trie.
135
136         >>> t = Trie()
137         >>> t.insert('test')
138         >>> t.insert('tess')
139         >>> t.insert('tessel')
140         >>> len(t)
141         3
142         >>> t.root
143         {'t': {'e': {'s': {'t': {'~END~': '~END~'}, 's': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
144         >>> t.__delitem__('test')
145         >>> len(t)
146         2
147         >>> t.root
148         {'t': {'e': {'s': {'s': {'~END~': '~END~', 'e': {'l': {'~END~': '~END~'}}}}}}}
149         >>> for x in t:
150         ...     print(x)
151         tess
152         tessel
153         >>> t.__delitem__('tessel')
154         >>> len(t)
155         1
156         >>> t.root
157         {'t': {'e': {'s': {'s': {'~END~': '~END~'}}}}}
158         >>> for x in t:
159         ...     print(x)
160         tess
161         >>> t.__delitem__('tess')
162         >>> len(t)
163         0
164         >>> t.root
165         {}
166         >>> t.insert('testy')
167         >>> len(t)
168         1
169
170         """
171         if item not in self:
172             raise KeyError(f"Node '{item}' is not in the trie")
173         self.delete_recursively(self.root, item)
174         self.length -= 1
175
176     def __len__(self):
177         """
178         Returns a count of the Trie's item population.
179
180         >>> t = Trie()
181         >>> len(t)
182         0
183         >>> t.insert('test')
184         >>> len(t)
185         1
186         >>> t.insert('testicle')
187         >>> len(t)
188         2
189
190         """
191         return self.length
192
193     def __iter__(self):
194         self.content_generator = self.generate_recursively(self.root, '')
195         return self
196
197     def generate_recursively(self, node, path: Sequence[Any]):
198         """
199         Generate items in the trie one by one.
200
201         >>> t = Trie()
202         >>> t.insert('test')
203         >>> t.insert('tickle')
204         >>> for item in t.generate_recursively(t.root, ''):
205         ...     print(item)
206         test
207         tickle
208
209         """
210         for child in node:
211             if child == self.end:
212                 yield path
213             else:
214                 yield from self.generate_recursively(node[child], path + child)
215
216     def __next__(self):
217         """
218         Iterate through the contents of the trie.
219
220         >>> t = Trie()
221         >>> t.insert('test')
222         >>> t.insert('tickle')
223         >>> for item in t:
224         ...     print(item)
225         test
226         tickle
227
228         """
229         ret = next(self.content_generator)
230         if ret is not None:
231             return ret
232         raise StopIteration
233
234     def successors(self, item: Sequence[Any]):
235         """
236         Return a list of the successors of an item.
237
238         >>> t = Trie()
239         >>> t.insert('what')
240         >>> t.insert('who')
241         >>> t.insert('when')
242         >>> t.successors('wh')
243         ['a', 'o', 'e']
244
245         >>> u = Trie()
246         >>> u.insert(['this', 'is', 'a', 'test'])
247         >>> u.insert(['this', 'is', 'a', 'robbery'])
248         >>> u.insert(['this', 'is', 'a', 'walrus'])
249         >>> u.successors(['this', 'is', 'a'])
250         ['test', 'robbery', 'walrus']
251
252         """
253         node = self.__traverse__(item)
254         if node is None:
255             return None
256         return [x for x in node if x != self.end]
257
258     def _repr_fancy(
259         self,
260         padding: str,
261         pointer: str,
262         node: Any,
263         has_sibling: bool,
264     ) -> str:
265         """
266         Helper that return a fancy representation of the Trie:
267         """
268         if node is None:
269             return ''
270         if node is not self.root:
271             ret = f'\n{padding}{pointer}'
272             if has_sibling:
273                 padding += '│  '
274             else:
275                 padding += '   '
276         else:
277             ret = f'{pointer}'
278
279         child_count = 0
280         for child in node:
281             if child != self.end:
282                 child_count += 1
283
284         for child in node:
285             if child != self.end:
286                 if child_count > 1:
287                     pointer = "├──"
288                     has_sibling = True
289                 else:
290                     pointer = "└──"
291                     has_sibling = False
292                 pointer += f'{child}'
293                 child_count -= 1
294                 ret += self._repr_fancy(padding, pointer, node[child], has_sibling)
295         return ret
296
297     def repr_brief(self, node, delimiter):
298         """
299         A friendly string representation of the contents of the Trie.
300
301         >>> t = Trie()
302         >>> t.insert([10, 0, 0, 1])
303         >>> t.insert([10, 0, 0, 2])
304         >>> t.insert([10, 10, 10, 1])
305         >>> t.insert([10, 10, 10, 2])
306         >>> t.repr_brief(t.root, '.')
307         '10.[0.0.[1,2],10.10.[1,2]]'
308
309         """
310         child_count = 0
311         my_rep = ''
312         for child in node:
313             if child != self.end:
314                 child_count += 1
315                 child_rep = self.repr_brief(node[child], delimiter)
316                 if len(child_rep) > 0:
317                     my_rep += str(child) + delimiter + child_rep + ","
318                 else:
319                     my_rep += str(child) + ","
320         if len(my_rep) > 1:
321             my_rep = my_rep[:-1]
322         if child_count > 1:
323             my_rep = f'[{my_rep}]'
324         return my_rep
325
326     def __repr__(self):
327         """
328         A friendly string representation of the contents of the Trie.  Under
329         the covers uses _repr_fancy.
330
331         >>> t = Trie()
332         >>> t.insert([10, 0, 0, 1])
333         >>> t.insert([10, 0, 0, 2])
334         >>> t.insert([10, 10, 10, 1])
335         >>> t.insert([10, 10, 10, 2])
336         >>> print(t)
337         *
338         └──10
339            ├──0
340            │  └──0
341            │     ├──1
342            │     └──2
343            └──10
344               └──10
345                  ├──1
346                  └──2
347
348         """
349         return self._repr_fancy('', '*', self.root, False)