Improve docs on bst.py again.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
9
10
11 class Comparable(Protocol):
12     """Anything that implements basic comparison methods such that it
13     can be compared to other instances of the same type.
14
15     Check out :meth:`functools.total_ordering`
16     (https://docs.python.org/3/library/functools.html#functools.total_ordering)
17     for an easy way to make your type comparable.
18     """
19
20     @abstractmethod
21     def __lt__(self, other: Any) -> bool:
22         ...
23
24     @abstractmethod
25     def __le__(self, other: Any) -> bool:
26         ...
27
28     @abstractmethod
29     def __eq__(self, other: Any) -> bool:
30         ...
31
32
33 class Node:
34     def __init__(self, value: Comparable) -> None:
35         """A BST node.  Just a left and right reference along with a
36         value.  Note that value can be anything as long as it
37         is :class:`Comparable` with other instances of itself.
38
39         Args:
40             value: a reference to the value of the node.  Must be
41                 :class:`Comparable` to other values.
42
43         """
44         self.left: Optional[Node] = None
45         self.right: Optional[Node] = None
46         self.value: Comparable = value
47
48
49 class BinarySearchTree(object):
50     def __init__(self):
51         self.root = None
52         self.count = 0
53         self.traverse = None
54
55     def get_root(self) -> Optional[Node]:
56         """
57         Returns:
58             The root of the BST
59         """
60
61         return self.root
62
63     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
64         """This is called immediately _after_ a new node is inserted."""
65         pass
66
67     def insert(self, value: Comparable) -> None:
68         """
69         Insert something into the tree.
70
71         Args:
72             value: the value to be inserted.
73
74         >>> t = BinarySearchTree()
75         >>> t.insert(10)
76         >>> t.insert(20)
77         >>> t.insert(5)
78         >>> len(t)
79         3
80
81         >>> t.get_root().value
82         10
83
84         """
85         if self.root is None:
86             self.root = Node(value)
87             self.count = 1
88             self._on_insert(None, self.root)
89         else:
90             self._insert(value, self.root)
91
92     def _insert(self, value: Comparable, node: Node):
93         """Insertion helper"""
94         if value < node.value:
95             if node.left is not None:
96                 self._insert(value, node.left)
97             else:
98                 node.left = Node(value)
99                 self.count += 1
100                 self._on_insert(node, node.left)
101         else:
102             if node.right is not None:
103                 self._insert(value, node.right)
104             else:
105                 node.right = Node(value)
106                 self.count += 1
107                 self._on_insert(node, node.right)
108
109     def __getitem__(self, value: Comparable) -> Optional[Node]:
110         """
111         Find an item in the tree and return its Node.  Returns
112         None if the item is not in the tree.
113
114         >>> t = BinarySearchTree()
115         >>> t[99]
116
117         >>> t.insert(10)
118         >>> t.insert(20)
119         >>> t.insert(5)
120         >>> t[10].value
121         10
122
123         >>> t[99]
124
125         """
126         if self.root is not None:
127             return self._find_exact(value, self.root)
128         return None
129
130     def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
131         """Recursively traverse the tree looking for a node with the
132         target value.  Return that node if it exists, otherwise return
133         None."""
134
135         if target == node.value:
136             return node
137         elif target < node.value and node.left is not None:
138             return self._find_exact(target, node.left)
139         elif target > node.value and node.right is not None:
140             return self._find_exact(target, node.right)
141         return None
142
143     def _find_lowest_node_less_than_or_equal_to(
144         self, target: Comparable, node: Optional[Node]
145     ) -> Optional[Node]:
146         """Find helper that returns the lowest node that is less
147         than or equal to the target value.  Returns None if target is
148         lower than the lowest node in the tree.
149
150         >>> t = BinarySearchTree()
151         >>> t.insert(50)
152         >>> t.insert(75)
153         >>> t.insert(25)
154         >>> t.insert(66)
155         >>> t.insert(22)
156         >>> t.insert(13)
157         >>> t.insert(85)
158         >>> t
159         50
160         ├──25
161         │  └──22
162         │     └──13
163         └──75
164            ├──66
165            └──85
166
167         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
168         25
169         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
170         50
171         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
172         85
173         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
174         22
175         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
176         13
177         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
178         66
179         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
180         75
181         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
182         True
183
184         """
185
186         if not node:
187             return None
188
189         if target == node.value:
190             return node
191
192         elif target > node.value:
193             if below := self._find_lowest_node_less_than_or_equal_to(
194                 target, node.right
195             ):
196                 return below
197             else:
198                 return node
199
200         else:
201             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
202
203     def _find_lowest_node_greater_than_or_equal_to(
204         self, target: Comparable, node: Optional[Node]
205     ) -> Optional[Node]:
206         """Find helper that returns the lowest node that is greater
207         than or equal to the target value.  Returns None if target is
208         higher than the greatest node in the tree.
209
210         >>> t = BinarySearchTree()
211         >>> t.insert(50)
212         >>> t.insert(75)
213         >>> t.insert(25)
214         >>> t.insert(66)
215         >>> t.insert(22)
216         >>> t.insert(13)
217         >>> t.insert(85)
218         >>> t
219         50
220         ├──25
221         │  └──22
222         │     └──13
223         └──75
224            ├──66
225            └──85
226
227         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
228         50
229         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
230         66
231         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
232         13
233         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
234         25
235         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
236         22
237         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
238         75
239         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
240         85
241         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
242         True
243
244         """
245
246         if not node:
247             return None
248
249         if target == node.value:
250             return node
251
252         elif target > node.value:
253             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
254
255         # If target < this node's value, either this node is the
256         # answer or the answer is in this node's left subtree.
257         else:
258             if below := self._find_lowest_node_greater_than_or_equal_to(
259                 target, node.left
260             ):
261                 return below
262             else:
263                 return node
264
265     def _parent_path(
266         self, current: Optional[Node], target: Node
267     ) -> List[Optional[Node]]:
268         """Internal helper"""
269         if current is None:
270             return [None]
271         ret: List[Optional[Node]] = [current]
272         if target.value == current.value:
273             return ret
274         elif target.value < current.value:
275             ret.extend(self._parent_path(current.left, target))
276             return ret
277         else:
278             assert target.value > current.value
279             ret.extend(self._parent_path(current.right, target))
280             return ret
281
282     def parent_path(self, node: Node) -> List[Optional[Node]]:
283         """Get a node's parent path.
284
285         Args:
286             node: the node to check
287
288         Returns:
289             a list of nodes representing the path from
290             the tree's root to the node.
291
292         .. note::
293
294             If the node does not exist in the tree, the last element
295             on the path will be None but the path will indicate the
296             ancestor path of that node were it to be inserted.
297
298         >>> t = BinarySearchTree()
299         >>> t.insert(50)
300         >>> t.insert(75)
301         >>> t.insert(25)
302         >>> t.insert(12)
303         >>> t.insert(33)
304         >>> t.insert(4)
305         >>> t.insert(88)
306         >>> t
307         50
308         ├──25
309         │  ├──12
310         │  │  └──4
311         │  └──33
312         └──75
313            └──88
314
315         >>> n = t[4]
316         >>> for x in t.parent_path(n):
317         ...     print(x.value)
318         50
319         25
320         12
321         4
322
323         >>> del t[4]
324         >>> for x in t.parent_path(n):
325         ...     if x is not None:
326         ...         print(x.value)
327         ...     else:
328         ...         print(x)
329         50
330         25
331         12
332         None
333
334         """
335         return self._parent_path(self.root, node)
336
337     def __delitem__(self, value: Comparable) -> bool:
338         """
339         Delete an item from the tree and preserve the BST property.
340
341         Args:
342             value: the value of the node to be deleted.
343
344         Returns:
345             True if the value was found and its associated node was
346             successfully deleted and False otherwise.
347
348         >>> t = BinarySearchTree()
349         >>> t.insert(50)
350         >>> t.insert(75)
351         >>> t.insert(25)
352         >>> t.insert(66)
353         >>> t.insert(22)
354         >>> t.insert(13)
355         >>> t.insert(85)
356         >>> t
357         50
358         ├──25
359         │  └──22
360         │     └──13
361         └──75
362            ├──66
363            └──85
364
365         >>> for value in t.iterate_inorder():
366         ...     print(value)
367         13
368         22
369         25
370         50
371         66
372         75
373         85
374
375         >>> del t[22]  # Note: bool result is discarded
376
377         >>> for value in t.iterate_inorder():
378         ...     print(value)
379         13
380         25
381         50
382         66
383         75
384         85
385
386         >>> t.__delitem__(13)
387         True
388         >>> for value in t.iterate_inorder():
389         ...     print(value)
390         25
391         50
392         66
393         75
394         85
395
396         >>> t.__delitem__(75)
397         True
398         >>> for value in t.iterate_inorder():
399         ...     print(value)
400         25
401         50
402         66
403         85
404         >>> t
405         50
406         ├──25
407         └──85
408            └──66
409
410         >>> t.__delitem__(99)
411         False
412
413         """
414         if self.root is not None:
415             ret = self._delete(value, None, self.root)
416             if ret:
417                 self.count -= 1
418                 if self.count == 0:
419                     self.root = None
420             return ret
421         return False
422
423     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
424         """This is called just after deleted was deleted from the tree"""
425         pass
426
427     def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
428         """Delete helper"""
429         if node.value == value:
430
431             # Deleting a leaf node
432             if node.left is None and node.right is None:
433                 if parent is not None:
434                     if parent.left == node:
435                         parent.left = None
436                     else:
437                         assert parent.right == node
438                         parent.right = None
439                 self._on_delete(parent, node)
440                 return True
441
442             # Node only has a right.
443             elif node.left is None:
444                 assert node.right is not None
445                 if parent is not None:
446                     if parent.left == node:
447                         parent.left = node.right
448                     else:
449                         assert parent.right == node
450                         parent.right = node.right
451                 self._on_delete(parent, node)
452                 return True
453
454             # Node only has a left.
455             elif node.right is None:
456                 assert node.left is not None
457                 if parent is not None:
458                     if parent.left == node:
459                         parent.left = node.left
460                     else:
461                         assert parent.right == node
462                         parent.right = node.left
463                 self._on_delete(parent, node)
464                 return True
465
466             # Node has both a left and right.
467             else:
468                 assert node.left is not None and node.right is not None
469                 descendent = node.right
470                 while descendent.left is not None:
471                     descendent = descendent.left
472                 node.value = descendent.value
473                 return self._delete(node.value, node, node.right)
474         elif value < node.value and node.left is not None:
475             return self._delete(value, node, node.left)
476         elif value > node.value and node.right is not None:
477             return self._delete(value, node, node.right)
478         return False
479
480     def __len__(self):
481         """
482         Returns:
483             The count of items in the tree.
484
485         >>> t = BinarySearchTree()
486         >>> len(t)
487         0
488         >>> t.insert(50)
489         >>> len(t)
490         1
491         >>> t.__delitem__(50)
492         True
493         >>> len(t)
494         0
495         >>> t.insert(75)
496         >>> t.insert(25)
497         >>> t.insert(66)
498         >>> t.insert(22)
499         >>> t.insert(13)
500         >>> t.insert(85)
501         >>> len(t)
502         6
503
504         """
505         return self.count
506
507     def __contains__(self, value: Comparable) -> bool:
508         """
509         Returns:
510             True if the item is in the tree; False otherwise.
511         """
512         return self.__getitem__(value) is not None
513
514     def _iterate_preorder(self, node: Node):
515         yield node.value
516         if node.left is not None:
517             yield from self._iterate_preorder(node.left)
518         if node.right is not None:
519             yield from self._iterate_preorder(node.right)
520
521     def _iterate_inorder(self, node: Node):
522         if node.left is not None:
523             yield from self._iterate_inorder(node.left)
524         yield node.value
525         if node.right is not None:
526             yield from self._iterate_inorder(node.right)
527
528     def _iterate_postorder(self, node: Node):
529         if node.left is not None:
530             yield from self._iterate_postorder(node.left)
531         if node.right is not None:
532             yield from self._iterate_postorder(node.right)
533         yield node.value
534
535     def iterate_preorder(self):
536         """
537         Returns:
538             A Generator that yields the tree's items in a
539             preorder traversal sequence.
540
541         >>> t = BinarySearchTree()
542         >>> t.insert(50)
543         >>> t.insert(75)
544         >>> t.insert(25)
545         >>> t.insert(66)
546         >>> t.insert(22)
547         >>> t.insert(13)
548
549         >>> for value in t.iterate_preorder():
550         ...     print(value)
551         50
552         25
553         22
554         13
555         75
556         66
557
558         """
559         if self.root is not None:
560             yield from self._iterate_preorder(self.root)
561
562     def iterate_inorder(self):
563         """
564         Returns:
565             A Generator that yield the tree's items in a preorder
566             traversal sequence.
567
568         >>> t = BinarySearchTree()
569         >>> t.insert(50)
570         >>> t.insert(75)
571         >>> t.insert(25)
572         >>> t.insert(66)
573         >>> t.insert(22)
574         >>> t.insert(13)
575         >>> t.insert(24)
576         >>> t
577         50
578         ├──25
579         │  └──22
580         │     ├──13
581         │     └──24
582         └──75
583            └──66
584
585         >>> for value in t.iterate_inorder():
586         ...     print(value)
587         13
588         22
589         24
590         25
591         50
592         66
593         75
594
595         """
596         if self.root is not None:
597             yield from self._iterate_inorder(self.root)
598
599     def iterate_postorder(self):
600         """
601         Returns:
602             A Generator that yield the tree's items in a preorder
603             traversal sequence.
604
605         >>> t = BinarySearchTree()
606         >>> t.insert(50)
607         >>> t.insert(75)
608         >>> t.insert(25)
609         >>> t.insert(66)
610         >>> t.insert(22)
611         >>> t.insert(13)
612
613         >>> for value in t.iterate_postorder():
614         ...     print(value)
615         13
616         22
617         25
618         66
619         75
620         50
621
622         """
623         if self.root is not None:
624             yield from self._iterate_postorder(self.root)
625
626     def _iterate_leaves(self, node: Node):
627         if node.left is not None:
628             yield from self._iterate_leaves(node.left)
629         if node.right is not None:
630             yield from self._iterate_leaves(node.right)
631         if node.left is None and node.right is None:
632             yield node.value
633
634     def iterate_leaves(self):
635         """
636         Returns:
637             A Gemerator that yielde only the leaf nodes in the
638             tree.
639
640         >>> t = BinarySearchTree()
641         >>> t.insert(50)
642         >>> t.insert(75)
643         >>> t.insert(25)
644         >>> t.insert(66)
645         >>> t.insert(22)
646         >>> t.insert(13)
647
648         >>> for value in t.iterate_leaves():
649         ...     print(value)
650         13
651         66
652
653         """
654         if self.root is not None:
655             yield from self._iterate_leaves(self.root)
656
657     def _iterate_by_depth(self, node: Node, depth: int):
658         if depth == 0:
659             yield node.value
660         else:
661             assert depth > 0
662             if node.left is not None:
663                 yield from self._iterate_by_depth(node.left, depth - 1)
664             if node.right is not None:
665                 yield from self._iterate_by_depth(node.right, depth - 1)
666
667     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
668         """
669         Args:
670             depth: the desired depth
671
672         Returns:
673             A Generator that yields nodes at the prescribed depth in
674             the tree.
675
676         >>> t = BinarySearchTree()
677         >>> t.insert(50)
678         >>> t.insert(75)
679         >>> t.insert(25)
680         >>> t.insert(66)
681         >>> t.insert(22)
682         >>> t.insert(13)
683
684         >>> for value in t.iterate_nodes_by_depth(2):
685         ...     print(value)
686         22
687         66
688
689         >>> for value in t.iterate_nodes_by_depth(3):
690         ...     print(value)
691         13
692
693         """
694         if self.root is not None:
695             yield from self._iterate_by_depth(self.root, depth)
696
697     def get_next_node(self, node: Node) -> Optional[Node]:
698         """
699         Args:
700             node: the node whose next greater successor is desired
701
702         Returns:
703             Given a tree node, returns the next greater node in the tree.
704             If the given node is the greatest node in the tree, returns None.
705
706         >>> t = BinarySearchTree()
707         >>> t.insert(50)
708         >>> t.insert(75)
709         >>> t.insert(25)
710         >>> t.insert(66)
711         >>> t.insert(22)
712         >>> t.insert(13)
713         >>> t.insert(23)
714         >>> t
715         50
716         ├──25
717         │  └──22
718         │     ├──13
719         │     └──23
720         └──75
721            └──66
722
723         >>> n = t[23]
724         >>> t.get_next_node(n).value
725         25
726
727         >>> n = t[50]
728         >>> t.get_next_node(n).value
729         66
730
731         >>> n = t[75]
732         >>> t.get_next_node(n) is None
733         True
734
735         """
736         if node.right is not None:
737             x = node.right
738             while x.left is not None:
739                 x = x.left
740             return x
741
742         path = self.parent_path(node)
743         assert path[-1] is not None
744         assert path[-1] == node
745         path = path[:-1]
746         path.reverse()
747         for ancestor in path:
748             assert ancestor is not None
749             if node != ancestor.right:
750                 return ancestor
751             node = ancestor
752         return None
753
754     def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
755         """
756         >>> t = BinarySearchTree()
757         >>> t.insert(50)
758         >>> t.insert(75)
759         >>> t.insert(25)
760         >>> t.insert(66)
761         >>> t.insert(22)
762         >>> t.insert(13)
763         >>> t.insert(23)
764         >>> t
765         50
766         ├──25
767         │  └──22
768         │     ├──13
769         │     └──23
770         └──75
771            └──66
772
773         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
774         ...     print(node.value)
775         22
776         23
777         25
778         50
779         66
780         """
781         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
782             lower, self.root
783         )
784         while node:
785             if lower <= node.value <= upper:
786                 yield node
787             node = self.get_next_node(node)
788
789     def _depth(self, node: Node, sofar: int) -> int:
790         depth_left = sofar + 1
791         depth_right = sofar + 1
792         if node.left is not None:
793             depth_left = self._depth(node.left, sofar + 1)
794         if node.right is not None:
795             depth_right = self._depth(node.right, sofar + 1)
796         return max(depth_left, depth_right)
797
798     def depth(self) -> int:
799         """
800         Returns:
801             The max height (depth) of the tree in plies (edge distance
802             from root).
803
804         >>> t = BinarySearchTree()
805         >>> t.depth()
806         0
807
808         >>> t.insert(50)
809         >>> t.depth()
810         1
811
812         >>> t.insert(65)
813         >>> t.depth()
814         2
815
816         >>> t.insert(33)
817         >>> t.depth()
818         2
819
820         >>> t.insert(2)
821         >>> t.insert(1)
822         >>> t.depth()
823         4
824
825         """
826         if self.root is None:
827             return 0
828         return self._depth(self.root, 0)
829
830     def height(self) -> int:
831         """Returns the height (i.e. max depth) of the tree"""
832         return self.depth()
833
834     def repr_traverse(
835         self,
836         padding: str,
837         pointer: str,
838         node: Optional[Node],
839         has_right_sibling: bool,
840     ) -> str:
841         if node is not None:
842             viz = f"\n{padding}{pointer}{node.value}"
843             if has_right_sibling:
844                 padding += "│  "
845             else:
846                 padding += "   "
847
848             pointer_right = "└──"
849             if node.right is not None:
850                 pointer_left = "├──"
851             else:
852                 pointer_left = "└──"
853
854             viz += self.repr_traverse(
855                 padding, pointer_left, node.left, node.right is not None
856             )
857             viz += self.repr_traverse(padding, pointer_right, node.right, False)
858             return viz
859         return ""
860
861     def __repr__(self):
862         """
863         Returns:
864             An ASCII string representation of the tree.
865
866         >>> t = BinarySearchTree()
867         >>> t.insert(50)
868         >>> t.insert(25)
869         >>> t.insert(75)
870         >>> t.insert(12)
871         >>> t.insert(33)
872         >>> t.insert(88)
873         >>> t.insert(55)
874         >>> t
875         50
876         ├──25
877         │  ├──12
878         │  └──33
879         └──75
880            ├──55
881            └──88
882         """
883         if self.root is None:
884             return ""
885
886         ret = f"{self.root.value}"
887         pointer_right = "└──"
888         if self.root.right is None:
889             pointer_left = "└──"
890         else:
891             pointer_left = "├──"
892
893         ret += self.repr_traverse(
894             "", pointer_left, self.root.left, self.root.left is not None
895         )
896         ret += self.repr_traverse("", pointer_right, self.root.right, False)
897         return ret
898
899
900 if __name__ == "__main__":
901     import doctest
902
903     doctest.testmod()