2e5e3ce95599811aecb553fe5b18e2ebd97c1434
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A binary search tree."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """Note that value can be anything as long as it is
13         comparable.  Check out @functools.total_ordering.
14
15         """
16         self.left: Optional[Node] = None
17         self.right: Optional[Node] = None
18         self.value = value
19
20
21 class BinarySearchTree(object):
22     def __init__(self):
23         self.root = None
24         self.count = 0
25         self.traverse = None
26
27     def get_root(self) -> Optional[Node]:
28         """:returns the root of the BST."""
29
30         return self.root
31
32     def insert(self, value: Any):
33         """
34         Insert something into the tree.
35
36         >>> t = BinarySearchTree()
37         >>> t.insert(10)
38         >>> t.insert(20)
39         >>> t.insert(5)
40         >>> len(t)
41         3
42
43         >>> t.get_root().value
44         10
45
46         """
47         if self.root is None:
48             self.root = Node(value)
49             self.count = 1
50         else:
51             self._insert(value, self.root)
52
53     def _insert(self, value: Any, node: Node):
54         """Insertion helper"""
55         if value < node.value:
56             if node.left is not None:
57                 self._insert(value, node.left)
58             else:
59                 node.left = Node(value)
60                 self.count += 1
61         else:
62             if node.right is not None:
63                 self._insert(value, node.right)
64             else:
65                 node.right = Node(value)
66                 self.count += 1
67
68     def __getitem__(self, value: Any) -> Optional[Node]:
69         """
70         Find an item in the tree and return its Node.  Returns
71         None if the item is not in the tree.
72
73         >>> t = BinarySearchTree()
74         >>> t[99]
75
76         >>> t.insert(10)
77         >>> t.insert(20)
78         >>> t.insert(5)
79         >>> t[10].value
80         10
81
82         >>> t[99]
83
84         """
85         if self.root is not None:
86             return self._find(value, self.root)
87         return None
88
89     def _find(self, value: Any, node: Node) -> Optional[Node]:
90         """Find helper"""
91         if value == node.value:
92             return node
93         elif value < node.value and node.left is not None:
94             return self._find(value, node.left)
95         elif value > node.value and node.right is not None:
96             return self._find(value, node.right)
97         return None
98
99     def _parent_path(
100         self, current: Optional[Node], target: Node
101     ) -> List[Optional[Node]]:
102         if current is None:
103             return [None]
104         ret: List[Optional[Node]] = [current]
105         if target.value == current.value:
106             return ret
107         elif target.value < current.value:
108             ret.extend(self._parent_path(current.left, target))
109             return ret
110         else:
111             assert target.value > current.value
112             ret.extend(self._parent_path(current.right, target))
113             return ret
114
115     def parent_path(self, node: Node) -> List[Optional[Node]]:
116         """Return a list of nodes representing the path from
117         the tree's root to the node argument.  If the node does
118         not exist in the tree for some reason, the last element
119         on the path will be None but the path will indicate the
120         ancestor path of that node were it inserted.
121
122         >>> t = BinarySearchTree()
123         >>> t.insert(50)
124         >>> t.insert(75)
125         >>> t.insert(25)
126         >>> t.insert(12)
127         >>> t.insert(33)
128         >>> t.insert(4)
129         >>> t.insert(88)
130         >>> t
131         50
132         ├──25
133         │  ├──12
134         │  │  └──4
135         │  └──33
136         └──75
137            └──88
138
139         >>> n = t[4]
140         >>> for x in t.parent_path(n):
141         ...     print(x.value)
142         50
143         25
144         12
145         4
146
147         >>> del t[4]
148         >>> for x in t.parent_path(n):
149         ...     if x is not None:
150         ...         print(x.value)
151         ...     else:
152         ...         print(x)
153         50
154         25
155         12
156         None
157
158         """
159         return self._parent_path(self.root, node)
160
161     def __delitem__(self, value: Any) -> bool:
162         """
163         Delete an item from the tree and preserve the BST property.
164
165         >>> t = BinarySearchTree()
166         >>> t.insert(50)
167         >>> t.insert(75)
168         >>> t.insert(25)
169         >>> t.insert(66)
170         >>> t.insert(22)
171         >>> t.insert(13)
172         >>> t.insert(85)
173         >>> t
174         50
175         ├──25
176         │  └──22
177         │     └──13
178         └──75
179            ├──66
180            └──85
181
182         >>> for value in t.iterate_inorder():
183         ...     print(value)
184         13
185         22
186         25
187         50
188         66
189         75
190         85
191
192         >>> del t[22]  # Note: bool result is discarded
193
194         >>> for value in t.iterate_inorder():
195         ...     print(value)
196         13
197         25
198         50
199         66
200         75
201         85
202
203         >>> t.__delitem__(13)
204         True
205         >>> for value in t.iterate_inorder():
206         ...     print(value)
207         25
208         50
209         66
210         75
211         85
212
213         >>> t.__delitem__(75)
214         True
215         >>> for value in t.iterate_inorder():
216         ...     print(value)
217         25
218         50
219         66
220         85
221         >>> t
222         50
223         ├──25
224         └──85
225            └──66
226
227         >>> t.__delitem__(99)
228         False
229
230         """
231         if self.root is not None:
232             ret = self._delete(value, None, self.root)
233             if ret:
234                 self.count -= 1
235                 if self.count == 0:
236                     self.root = None
237             return ret
238         return False
239
240     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
241         """Delete helper"""
242         if node.value == value:
243             # Deleting a leaf node
244             if node.left is None and node.right is None:
245                 if parent is not None:
246                     if parent.left == node:
247                         parent.left = None
248                     else:
249                         assert parent.right == node
250                         parent.right = None
251                 return True
252
253             # Node only has a right.
254             elif node.left is None:
255                 assert node.right is not None
256                 if parent is not None:
257                     if parent.left == node:
258                         parent.left = node.right
259                     else:
260                         assert parent.right == node
261                         parent.right = node.right
262                 return True
263
264             # Node only has a left.
265             elif node.right is None:
266                 assert node.left is not None
267                 if parent is not None:
268                     if parent.left == node:
269                         parent.left = node.left
270                     else:
271                         assert parent.right == node
272                         parent.right = node.left
273                 return True
274
275             # Node has both a left and right.
276             else:
277                 assert node.left is not None and node.right is not None
278                 descendent = node.right
279                 while descendent.left is not None:
280                     descendent = descendent.left
281                 node.value = descendent.value
282                 return self._delete(node.value, node, node.right)
283         elif value < node.value and node.left is not None:
284             return self._delete(value, node, node.left)
285         elif value > node.value and node.right is not None:
286             return self._delete(value, node, node.right)
287         return False
288
289     def __len__(self):
290         """
291         Returns the count of items in the tree.
292
293         >>> t = BinarySearchTree()
294         >>> len(t)
295         0
296         >>> t.insert(50)
297         >>> len(t)
298         1
299         >>> t.__delitem__(50)
300         True
301         >>> len(t)
302         0
303         >>> t.insert(75)
304         >>> t.insert(25)
305         >>> t.insert(66)
306         >>> t.insert(22)
307         >>> t.insert(13)
308         >>> t.insert(85)
309         >>> len(t)
310         6
311
312         """
313         return self.count
314
315     def __contains__(self, value: Any) -> bool:
316         """
317         Returns True if the item is in the tree; False otherwise.
318         """
319         return self.__getitem__(value) is not None
320
321     def _iterate_preorder(self, node: Node):
322         yield node.value
323         if node.left is not None:
324             yield from self._iterate_preorder(node.left)
325         if node.right is not None:
326             yield from self._iterate_preorder(node.right)
327
328     def _iterate_inorder(self, node: Node):
329         if node.left is not None:
330             yield from self._iterate_inorder(node.left)
331         yield node.value
332         if node.right is not None:
333             yield from self._iterate_inorder(node.right)
334
335     def _iterate_postorder(self, node: Node):
336         if node.left is not None:
337             yield from self._iterate_postorder(node.left)
338         if node.right is not None:
339             yield from self._iterate_postorder(node.right)
340         yield node.value
341
342     def iterate_preorder(self):
343         """
344         Yield the tree's items in a preorder traversal sequence.
345
346         >>> t = BinarySearchTree()
347         >>> t.insert(50)
348         >>> t.insert(75)
349         >>> t.insert(25)
350         >>> t.insert(66)
351         >>> t.insert(22)
352         >>> t.insert(13)
353
354         >>> for value in t.iterate_preorder():
355         ...     print(value)
356         50
357         25
358         22
359         13
360         75
361         66
362
363         """
364         if self.root is not None:
365             yield from self._iterate_preorder(self.root)
366
367     def iterate_inorder(self):
368         """
369         Yield the tree's items in a preorder traversal sequence.
370
371         >>> t = BinarySearchTree()
372         >>> t.insert(50)
373         >>> t.insert(75)
374         >>> t.insert(25)
375         >>> t.insert(66)
376         >>> t.insert(22)
377         >>> t.insert(13)
378         >>> t.insert(24)
379         >>> t
380         50
381         ├──25
382         │  └──22
383         │     ├──13
384         │     └──24
385         └──75
386            └──66
387
388         >>> for value in t.iterate_inorder():
389         ...     print(value)
390         13
391         22
392         24
393         25
394         50
395         66
396         75
397
398         """
399         if self.root is not None:
400             yield from self._iterate_inorder(self.root)
401
402     def iterate_postorder(self):
403         """
404         Yield the tree's items in a preorder traversal sequence.
405
406         >>> t = BinarySearchTree()
407         >>> t.insert(50)
408         >>> t.insert(75)
409         >>> t.insert(25)
410         >>> t.insert(66)
411         >>> t.insert(22)
412         >>> t.insert(13)
413
414         >>> for value in t.iterate_postorder():
415         ...     print(value)
416         13
417         22
418         25
419         66
420         75
421         50
422
423         """
424         if self.root is not None:
425             yield from self._iterate_postorder(self.root)
426
427     def _iterate_leaves(self, node: Node):
428         if node.left is not None:
429             yield from self._iterate_leaves(node.left)
430         if node.right is not None:
431             yield from self._iterate_leaves(node.right)
432         if node.left is None and node.right is None:
433             yield node.value
434
435     def iterate_leaves(self):
436         """
437         Iterate only the leaf nodes in the tree.
438
439         >>> t = BinarySearchTree()
440         >>> t.insert(50)
441         >>> t.insert(75)
442         >>> t.insert(25)
443         >>> t.insert(66)
444         >>> t.insert(22)
445         >>> t.insert(13)
446
447         >>> for value in t.iterate_leaves():
448         ...     print(value)
449         13
450         66
451
452         """
453         if self.root is not None:
454             yield from self._iterate_leaves(self.root)
455
456     def _iterate_by_depth(self, node: Node, depth: int):
457         if depth == 0:
458             yield node.value
459         else:
460             assert depth > 0
461             if node.left is not None:
462                 yield from self._iterate_by_depth(node.left, depth - 1)
463             if node.right is not None:
464                 yield from self._iterate_by_depth(node.right, depth - 1)
465
466     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
467         """
468         Iterate only the leaf nodes in the tree.
469
470         >>> t = BinarySearchTree()
471         >>> t.insert(50)
472         >>> t.insert(75)
473         >>> t.insert(25)
474         >>> t.insert(66)
475         >>> t.insert(22)
476         >>> t.insert(13)
477
478         >>> for value in t.iterate_nodes_by_depth(2):
479         ...     print(value)
480         22
481         66
482
483         >>> for value in t.iterate_nodes_by_depth(3):
484         ...     print(value)
485         13
486
487         """
488         if self.root is not None:
489             yield from self._iterate_by_depth(self.root, depth)
490
491     def get_next_node(self, node: Node) -> Node:
492         """
493         Given a tree node, get the next greater node in the tree.
494
495         >>> t = BinarySearchTree()
496         >>> t.insert(50)
497         >>> t.insert(75)
498         >>> t.insert(25)
499         >>> t.insert(66)
500         >>> t.insert(22)
501         >>> t.insert(13)
502         >>> t.insert(23)
503         >>> t
504         50
505         ├──25
506         │  └──22
507         │     ├──13
508         │     └──23
509         └──75
510            └──66
511
512         >>> n = t[23]
513         >>> t.get_next_node(n).value
514         25
515
516         >>> n = t[50]
517         >>> t.get_next_node(n).value
518         66
519
520         """
521         if node.right is not None:
522             x = node.right
523             while x.left is not None:
524                 x = x.left
525             return x
526
527         path = self.parent_path(node)
528         assert path[-1] is not None
529         assert path[-1] == node
530         path = path[:-1]
531         path.reverse()
532         for ancestor in path:
533             assert ancestor is not None
534             if node != ancestor.right:
535                 return ancestor
536             node = ancestor
537         raise Exception()
538
539     def _depth(self, node: Node, sofar: int) -> int:
540         depth_left = sofar + 1
541         depth_right = sofar + 1
542         if node.left is not None:
543             depth_left = self._depth(node.left, sofar + 1)
544         if node.right is not None:
545             depth_right = self._depth(node.right, sofar + 1)
546         return max(depth_left, depth_right)
547
548     def depth(self) -> int:
549         """
550         Returns the max height (depth) of the tree in plies (edge distance
551         from root).
552
553         >>> t = BinarySearchTree()
554         >>> t.depth()
555         0
556
557         >>> t.insert(50)
558         >>> t.depth()
559         1
560
561         >>> t.insert(65)
562         >>> t.depth()
563         2
564
565         >>> t.insert(33)
566         >>> t.depth()
567         2
568
569         >>> t.insert(2)
570         >>> t.insert(1)
571         >>> t.depth()
572         4
573
574         """
575         if self.root is None:
576             return 0
577         return self._depth(self.root, 0)
578
579     def height(self) -> int:
580         """Returns the height (i.e. max depth) of the tree"""
581         return self.depth()
582
583     def repr_traverse(
584         self,
585         padding: str,
586         pointer: str,
587         node: Optional[Node],
588         has_right_sibling: bool,
589     ) -> str:
590         if node is not None:
591             viz = f'\n{padding}{pointer}{node.value}'
592             if has_right_sibling:
593                 padding += "│  "
594             else:
595                 padding += '   '
596
597             pointer_right = "└──"
598             if node.right is not None:
599                 pointer_left = "├──"
600             else:
601                 pointer_left = "└──"
602
603             viz += self.repr_traverse(
604                 padding, pointer_left, node.left, node.right is not None
605             )
606             viz += self.repr_traverse(padding, pointer_right, node.right, False)
607             return viz
608         return ""
609
610     def __repr__(self):
611         """
612         Draw the tree in ASCII.
613
614         >>> t = BinarySearchTree()
615         >>> t.insert(50)
616         >>> t.insert(25)
617         >>> t.insert(75)
618         >>> t.insert(12)
619         >>> t.insert(33)
620         >>> t.insert(88)
621         >>> t.insert(55)
622         >>> t
623         50
624         ├──25
625         │  ├──12
626         │  └──33
627         └──75
628            ├──55
629            └──88
630         """
631         if self.root is None:
632             return ""
633
634         ret = f'{self.root.value}'
635         pointer_right = "└──"
636         if self.root.right is None:
637             pointer_left = "└──"
638         else:
639             pointer_left = "├──"
640
641         ret += self.repr_traverse(
642             '', pointer_left, self.root.left, self.root.left is not None
643         )
644         ret += self.repr_traverse('', pointer_right, self.root.right, False)
645         return ret