1efed52838cb852259f7881270c208d3fb0f50ad
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         A BST node.  Note that value can be anything as long as it
14         is comparable.  Check out :meth:`functools.total_ordering`
15         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
16
17         Args:
18             value: a reference to the value of the node.
19         """
20         self.left: Optional[Node] = None
21         self.right: Optional[Node] = None
22         self.value = value
23
24
25 class BinarySearchTree(object):
26     def __init__(self):
27         self.root = None
28         self.count = 0
29         self.traverse = None
30
31     def get_root(self) -> Optional[Node]:
32         """
33         Returns:
34             The root of the BST
35         """
36
37         return self.root
38
39     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
40         """This is called immediately _after_ a new node is inserted."""
41         pass
42
43     def insert(self, value: Any) -> None:
44         """
45         Insert something into the tree.
46
47         Args:
48             value: the value to be inserted.
49
50         >>> t = BinarySearchTree()
51         >>> t.insert(10)
52         >>> t.insert(20)
53         >>> t.insert(5)
54         >>> len(t)
55         3
56
57         >>> t.get_root().value
58         10
59
60         """
61         if self.root is None:
62             self.root = Node(value)
63             self.count = 1
64             self._on_insert(None, self.root)
65         else:
66             self._insert(value, self.root)
67
68     def _insert(self, value: Any, node: Node):
69         """Insertion helper"""
70         if value < node.value:
71             if node.left is not None:
72                 self._insert(value, node.left)
73             else:
74                 node.left = Node(value)
75                 self.count += 1
76                 self._on_insert(node, node.left)
77         else:
78             if node.right is not None:
79                 self._insert(value, node.right)
80             else:
81                 node.right = Node(value)
82                 self.count += 1
83                 self._on_insert(node, node.right)
84
85     def __getitem__(self, value: Any) -> Optional[Node]:
86         """
87         Find an item in the tree and return its Node.  Returns
88         None if the item is not in the tree.
89
90         >>> t = BinarySearchTree()
91         >>> t[99]
92
93         >>> t.insert(10)
94         >>> t.insert(20)
95         >>> t.insert(5)
96         >>> t[10].value
97         10
98
99         >>> t[99]
100
101         """
102         if self.root is not None:
103             return self._find(value, self.root)
104         return None
105
106     def _find(self, value: Any, node: Node) -> Optional[Node]:
107         """Find helper"""
108         if value == node.value:
109             return node
110         elif value < node.value and node.left is not None:
111             return self._find(value, node.left)
112         elif value > node.value and node.right is not None:
113             return self._find(value, node.right)
114         return None
115
116     def _parent_path(
117         self, current: Optional[Node], target: Node
118     ) -> List[Optional[Node]]:
119         """Internal helper"""
120         if current is None:
121             return [None]
122         ret: List[Optional[Node]] = [current]
123         if target.value == current.value:
124             return ret
125         elif target.value < current.value:
126             ret.extend(self._parent_path(current.left, target))
127             return ret
128         else:
129             assert target.value > current.value
130             ret.extend(self._parent_path(current.right, target))
131             return ret
132
133     def parent_path(self, node: Node) -> List[Optional[Node]]:
134         """Get a node's parent path.
135
136         Args:
137             node: the node to check
138
139         Returns:
140             a list of nodes representing the path from
141             the tree's root to the node.
142
143         .. note::
144
145             If the node does not exist in the tree, the last element
146             on the path will be None but the path will indicate the
147             ancestor path of that node were it to be inserted.
148
149         >>> t = BinarySearchTree()
150         >>> t.insert(50)
151         >>> t.insert(75)
152         >>> t.insert(25)
153         >>> t.insert(12)
154         >>> t.insert(33)
155         >>> t.insert(4)
156         >>> t.insert(88)
157         >>> t
158         50
159         ├──25
160         │  ├──12
161         │  │  └──4
162         │  └──33
163         └──75
164            └──88
165
166         >>> n = t[4]
167         >>> for x in t.parent_path(n):
168         ...     print(x.value)
169         50
170         25
171         12
172         4
173
174         >>> del t[4]
175         >>> for x in t.parent_path(n):
176         ...     if x is not None:
177         ...         print(x.value)
178         ...     else:
179         ...         print(x)
180         50
181         25
182         12
183         None
184
185         """
186         return self._parent_path(self.root, node)
187
188     def __delitem__(self, value: Any) -> bool:
189         """
190         Delete an item from the tree and preserve the BST property.
191
192         Args:
193             value: the value of the node to be deleted.
194
195         Returns:
196             True if the value was found and its associated node was
197             successfully deleted and False otherwise.
198
199         >>> t = BinarySearchTree()
200         >>> t.insert(50)
201         >>> t.insert(75)
202         >>> t.insert(25)
203         >>> t.insert(66)
204         >>> t.insert(22)
205         >>> t.insert(13)
206         >>> t.insert(85)
207         >>> t
208         50
209         ├──25
210         │  └──22
211         │     └──13
212         └──75
213            ├──66
214            └──85
215
216         >>> for value in t.iterate_inorder():
217         ...     print(value)
218         13
219         22
220         25
221         50
222         66
223         75
224         85
225
226         >>> del t[22]  # Note: bool result is discarded
227
228         >>> for value in t.iterate_inorder():
229         ...     print(value)
230         13
231         25
232         50
233         66
234         75
235         85
236
237         >>> t.__delitem__(13)
238         True
239         >>> for value in t.iterate_inorder():
240         ...     print(value)
241         25
242         50
243         66
244         75
245         85
246
247         >>> t.__delitem__(75)
248         True
249         >>> for value in t.iterate_inorder():
250         ...     print(value)
251         25
252         50
253         66
254         85
255         >>> t
256         50
257         ├──25
258         └──85
259            └──66
260
261         >>> t.__delitem__(99)
262         False
263
264         """
265         if self.root is not None:
266             ret = self._delete(value, None, self.root)
267             if ret:
268                 self.count -= 1
269                 if self.count == 0:
270                     self.root = None
271             return ret
272         return False
273
274     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
275         """This is called just before deleted is deleted --
276         i.e. before the tree changes."""
277         pass
278
279     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
280         """Delete helper"""
281         if node.value == value:
282
283             # Deleting a leaf node
284             if node.left is None and node.right is None:
285                 self._on_delete(parent, node)
286                 if parent is not None:
287                     if parent.left == node:
288                         parent.left = None
289                     else:
290                         assert parent.right == node
291                         parent.right = None
292                 return True
293
294             # Node only has a right.
295             elif node.left is None:
296                 assert node.right is not None
297                 self._on_delete(parent, node)
298                 if parent is not None:
299                     if parent.left == node:
300                         parent.left = node.right
301                     else:
302                         assert parent.right == node
303                         parent.right = node.right
304                 return True
305
306             # Node only has a left.
307             elif node.right is None:
308                 assert node.left is not None
309                 self._on_delete(parent, node)
310                 if parent is not None:
311                     if parent.left == node:
312                         parent.left = node.left
313                     else:
314                         assert parent.right == node
315                         parent.right = node.left
316                 return True
317
318             # Node has both a left and right.
319             else:
320                 assert node.left is not None and node.right is not None
321                 descendent = node.right
322                 while descendent.left is not None:
323                     descendent = descendent.left
324                 node.value = descendent.value
325                 return self._delete(node.value, node, node.right)
326         elif value < node.value and node.left is not None:
327             return self._delete(value, node, node.left)
328         elif value > node.value and node.right is not None:
329             return self._delete(value, node, node.right)
330         return False
331
332     def __len__(self):
333         """
334         Returns:
335             The count of items in the tree.
336
337         >>> t = BinarySearchTree()
338         >>> len(t)
339         0
340         >>> t.insert(50)
341         >>> len(t)
342         1
343         >>> t.__delitem__(50)
344         True
345         >>> len(t)
346         0
347         >>> t.insert(75)
348         >>> t.insert(25)
349         >>> t.insert(66)
350         >>> t.insert(22)
351         >>> t.insert(13)
352         >>> t.insert(85)
353         >>> len(t)
354         6
355
356         """
357         return self.count
358
359     def __contains__(self, value: Any) -> bool:
360         """
361         Returns:
362             True if the item is in the tree; False otherwise.
363         """
364         return self.__getitem__(value) is not None
365
366     def _iterate_preorder(self, node: Node):
367         yield node.value
368         if node.left is not None:
369             yield from self._iterate_preorder(node.left)
370         if node.right is not None:
371             yield from self._iterate_preorder(node.right)
372
373     def _iterate_inorder(self, node: Node):
374         if node.left is not None:
375             yield from self._iterate_inorder(node.left)
376         yield node.value
377         if node.right is not None:
378             yield from self._iterate_inorder(node.right)
379
380     def _iterate_postorder(self, node: Node):
381         if node.left is not None:
382             yield from self._iterate_postorder(node.left)
383         if node.right is not None:
384             yield from self._iterate_postorder(node.right)
385         yield node.value
386
387     def iterate_preorder(self):
388         """
389         Returns:
390             A Generator that yields the tree's items in a
391             preorder traversal sequence.
392
393         >>> t = BinarySearchTree()
394         >>> t.insert(50)
395         >>> t.insert(75)
396         >>> t.insert(25)
397         >>> t.insert(66)
398         >>> t.insert(22)
399         >>> t.insert(13)
400
401         >>> for value in t.iterate_preorder():
402         ...     print(value)
403         50
404         25
405         22
406         13
407         75
408         66
409
410         """
411         if self.root is not None:
412             yield from self._iterate_preorder(self.root)
413
414     def iterate_inorder(self):
415         """
416         Returns:
417             A Generator that yield the tree's items in a preorder
418             traversal sequence.
419
420         >>> t = BinarySearchTree()
421         >>> t.insert(50)
422         >>> t.insert(75)
423         >>> t.insert(25)
424         >>> t.insert(66)
425         >>> t.insert(22)
426         >>> t.insert(13)
427         >>> t.insert(24)
428         >>> t
429         50
430         ├──25
431         │  └──22
432         │     ├──13
433         │     └──24
434         └──75
435            └──66
436
437         >>> for value in t.iterate_inorder():
438         ...     print(value)
439         13
440         22
441         24
442         25
443         50
444         66
445         75
446
447         """
448         if self.root is not None:
449             yield from self._iterate_inorder(self.root)
450
451     def iterate_postorder(self):
452         """
453         Returns:
454             A Generator that yield the tree's items in a preorder
455             traversal sequence.
456
457         >>> t = BinarySearchTree()
458         >>> t.insert(50)
459         >>> t.insert(75)
460         >>> t.insert(25)
461         >>> t.insert(66)
462         >>> t.insert(22)
463         >>> t.insert(13)
464
465         >>> for value in t.iterate_postorder():
466         ...     print(value)
467         13
468         22
469         25
470         66
471         75
472         50
473
474         """
475         if self.root is not None:
476             yield from self._iterate_postorder(self.root)
477
478     def _iterate_leaves(self, node: Node):
479         if node.left is not None:
480             yield from self._iterate_leaves(node.left)
481         if node.right is not None:
482             yield from self._iterate_leaves(node.right)
483         if node.left is None and node.right is None:
484             yield node.value
485
486     def iterate_leaves(self):
487         """
488         Returns:
489             A Gemerator that yielde only the leaf nodes in the
490             tree.
491
492         >>> t = BinarySearchTree()
493         >>> t.insert(50)
494         >>> t.insert(75)
495         >>> t.insert(25)
496         >>> t.insert(66)
497         >>> t.insert(22)
498         >>> t.insert(13)
499
500         >>> for value in t.iterate_leaves():
501         ...     print(value)
502         13
503         66
504
505         """
506         if self.root is not None:
507             yield from self._iterate_leaves(self.root)
508
509     def _iterate_by_depth(self, node: Node, depth: int):
510         if depth == 0:
511             yield node.value
512         else:
513             assert depth > 0
514             if node.left is not None:
515                 yield from self._iterate_by_depth(node.left, depth - 1)
516             if node.right is not None:
517                 yield from self._iterate_by_depth(node.right, depth - 1)
518
519     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
520         """
521         Args:
522             depth: the desired depth
523
524         Returns:
525             A Generator that yields nodes at the prescribed depth in
526             the tree.
527
528         >>> t = BinarySearchTree()
529         >>> t.insert(50)
530         >>> t.insert(75)
531         >>> t.insert(25)
532         >>> t.insert(66)
533         >>> t.insert(22)
534         >>> t.insert(13)
535
536         >>> for value in t.iterate_nodes_by_depth(2):
537         ...     print(value)
538         22
539         66
540
541         >>> for value in t.iterate_nodes_by_depth(3):
542         ...     print(value)
543         13
544
545         """
546         if self.root is not None:
547             yield from self._iterate_by_depth(self.root, depth)
548
549     def get_next_node(self, node: Node) -> Node:
550         """
551         Args:
552             node: the node whose next greater successor is desired
553
554         Returns:
555             Given a tree node, returns the next greater node in the tree.
556
557         >>> t = BinarySearchTree()
558         >>> t.insert(50)
559         >>> t.insert(75)
560         >>> t.insert(25)
561         >>> t.insert(66)
562         >>> t.insert(22)
563         >>> t.insert(13)
564         >>> t.insert(23)
565         >>> t
566         50
567         ├──25
568         │  └──22
569         │     ├──13
570         │     └──23
571         └──75
572            └──66
573
574         >>> n = t[23]
575         >>> t.get_next_node(n).value
576         25
577
578         >>> n = t[50]
579         >>> t.get_next_node(n).value
580         66
581
582         """
583         if node.right is not None:
584             x = node.right
585             while x.left is not None:
586                 x = x.left
587             return x
588
589         path = self.parent_path(node)
590         assert path[-1] is not None
591         assert path[-1] == node
592         path = path[:-1]
593         path.reverse()
594         for ancestor in path:
595             assert ancestor is not None
596             if node != ancestor.right:
597                 return ancestor
598             node = ancestor
599         raise Exception()
600
601     def _depth(self, node: Node, sofar: int) -> int:
602         depth_left = sofar + 1
603         depth_right = sofar + 1
604         if node.left is not None:
605             depth_left = self._depth(node.left, sofar + 1)
606         if node.right is not None:
607             depth_right = self._depth(node.right, sofar + 1)
608         return max(depth_left, depth_right)
609
610     def depth(self) -> int:
611         """
612         Returns:
613             The max height (depth) of the tree in plies (edge distance
614             from root).
615
616         >>> t = BinarySearchTree()
617         >>> t.depth()
618         0
619
620         >>> t.insert(50)
621         >>> t.depth()
622         1
623
624         >>> t.insert(65)
625         >>> t.depth()
626         2
627
628         >>> t.insert(33)
629         >>> t.depth()
630         2
631
632         >>> t.insert(2)
633         >>> t.insert(1)
634         >>> t.depth()
635         4
636
637         """
638         if self.root is None:
639             return 0
640         return self._depth(self.root, 0)
641
642     def height(self) -> int:
643         """Returns the height (i.e. max depth) of the tree"""
644         return self.depth()
645
646     def repr_traverse(
647         self,
648         padding: str,
649         pointer: str,
650         node: Optional[Node],
651         has_right_sibling: bool,
652     ) -> str:
653         if node is not None:
654             viz = f"\n{padding}{pointer}{node.value}"
655             if has_right_sibling:
656                 padding += "│  "
657             else:
658                 padding += "   "
659
660             pointer_right = "└──"
661             if node.right is not None:
662                 pointer_left = "├──"
663             else:
664                 pointer_left = "└──"
665
666             viz += self.repr_traverse(
667                 padding, pointer_left, node.left, node.right is not None
668             )
669             viz += self.repr_traverse(padding, pointer_right, node.right, False)
670             return viz
671         return ""
672
673     def __repr__(self):
674         """
675         Returns:
676             An ASCII string representation of the tree.
677
678         >>> t = BinarySearchTree()
679         >>> t.insert(50)
680         >>> t.insert(25)
681         >>> t.insert(75)
682         >>> t.insert(12)
683         >>> t.insert(33)
684         >>> t.insert(88)
685         >>> t.insert(55)
686         >>> t
687         50
688         ├──25
689         │  ├──12
690         │  └──33
691         └──75
692            ├──55
693            └──88
694         """
695         if self.root is None:
696             return ""
697
698         ret = f"{self.root.value}"
699         pointer_right = "└──"
700         if self.root.right is None:
701             pointer_left = "└──"
702         else:
703             pointer_left = "├──"
704
705         ret += self.repr_traverse(
706             "", pointer_left, self.root.left, self.root.left is not None
707         )
708         ret += self.repr_traverse("", pointer_right, self.root.right, False)
709         return ret
710
711
712 if __name__ == "__main__":
713     import doctest
714
715     doctest.testmod()