Cleanup code and update docs in bst.py.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
9
10
11 class Comparable(Protocol):
12     """Anything that implements basic comparison methods such that it
13     can be compared to other instances of the same type.
14
15     Check out :meth:`functools.total_ordering`
16     (https://docs.python.org/3/library/functools.html#functools.total_ordering)
17     for an easy way to make your type comparable.
18     """
19
20     @abstractmethod
21     def __lt__(self, other: Any) -> bool:
22         ...
23
24     @abstractmethod
25     def __le__(self, other: Any) -> bool:
26         ...
27
28     @abstractmethod
29     def __eq__(self, other: Any) -> bool:
30         ...
31
32
33 class Node:
34     def __init__(self, value: Comparable) -> None:
35         """A BST node.  Just a left and right reference along with a
36         value.  Note that value can be anything as long as it
37         is :class:`Comparable` with other instances of itself.
38
39         Args:
40             value: a reference to the value of the node.  Must be
41                 :class:`Comparable` to other values.
42
43         """
44         self.left: Optional[Node] = None
45         self.right: Optional[Node] = None
46         self.value: Comparable = value
47
48
49 class BinarySearchTree(object):
50     def __init__(self):
51         self.root = None
52         self.count = 0
53         self.traverse = None
54
55     def get_root(self) -> Optional[Node]:
56         """
57         Returns:
58             The root of the BST
59         """
60
61         return self.root
62
63     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
64         """This is called immediately _after_ a new node is inserted."""
65         pass
66
67     def insert(self, value: Comparable) -> None:
68         """
69         Insert something into the tree in :math:`O(log_2 n)` time.
70
71         Args:
72             value: the value to be inserted.
73
74         >>> t = BinarySearchTree()
75         >>> t.insert(10)
76         >>> t.insert(20)
77         >>> t.insert(5)
78         >>> len(t)
79         3
80
81         >>> t.get_root().value
82         10
83
84         """
85         if self.root is None:
86             self.root = Node(value)
87             self.count = 1
88             self._on_insert(None, self.root)
89         else:
90             self._insert(value, self.root)
91
92     def _insert(self, value: Comparable, node: Node):
93         """Insertion helper"""
94         if value < node.value:
95             if node.left is not None:
96                 self._insert(value, node.left)
97             else:
98                 node.left = Node(value)
99                 self.count += 1
100                 self._on_insert(node, node.left)
101         else:
102             if node.right is not None:
103                 self._insert(value, node.right)
104             else:
105                 node.right = Node(value)
106                 self.count += 1
107                 self._on_insert(node, node.right)
108
109     def __getitem__(self, value: Comparable) -> Optional[Node]:
110         """
111         Find an item in the tree and return its Node in
112         :math:`O(log_2 n)` time.  Returns None if the item is not in
113         the tree.
114
115         >>> t = BinarySearchTree()
116         >>> t[99]
117
118         >>> t.insert(10)
119         >>> t.insert(20)
120         >>> t.insert(5)
121         >>> t[10].value
122         10
123
124         >>> t[99]
125
126         """
127         if self.root is not None:
128             return self._find_exact(value, self.root)
129         return None
130
131     def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
132         """Recursively traverse the tree looking for a node with the
133         target value.  Return that node if it exists, otherwise return
134         None."""
135
136         if target == node.value:
137             return node
138         elif target < node.value and node.left is not None:
139             return self._find_exact(target, node.left)
140         elif target > node.value and node.right is not None:
141             return self._find_exact(target, node.right)
142         return None
143
144     def _find_lowest_node_less_than_or_equal_to(
145         self, target: Comparable, node: Optional[Node]
146     ) -> Optional[Node]:
147         """Find helper that returns the lowest node that is less
148         than or equal to the target value.  Returns None if target is
149         lower than the lowest node in the tree.
150
151         >>> t = BinarySearchTree()
152         >>> t.insert(50)
153         >>> t.insert(75)
154         >>> t.insert(25)
155         >>> t.insert(66)
156         >>> t.insert(22)
157         >>> t.insert(13)
158         >>> t.insert(85)
159         >>> t
160         50
161         ├──25
162         │  └──22
163         │     └──13
164         └──75
165            ├──66
166            └──85
167
168         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
169         25
170         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
171         50
172         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
173         85
174         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
175         22
176         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
177         13
178         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
179         66
180         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
181         75
182         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
183         True
184
185         """
186
187         if not node:
188             return None
189
190         if target == node.value:
191             return node
192
193         elif target > node.value:
194             if below := self._find_lowest_node_less_than_or_equal_to(
195                 target, node.right
196             ):
197                 return below
198             else:
199                 return node
200
201         else:
202             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
203
204     def _find_lowest_node_greater_than_or_equal_to(
205         self, target: Comparable, node: Optional[Node]
206     ) -> Optional[Node]:
207         """Find helper that returns the lowest node that is greater
208         than or equal to the target value.  Returns None if target is
209         higher than the greatest node in the tree.
210
211         >>> t = BinarySearchTree()
212         >>> t.insert(50)
213         >>> t.insert(75)
214         >>> t.insert(25)
215         >>> t.insert(66)
216         >>> t.insert(22)
217         >>> t.insert(13)
218         >>> t.insert(85)
219         >>> t
220         50
221         ├──25
222         │  └──22
223         │     └──13
224         └──75
225            ├──66
226            └──85
227
228         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
229         50
230         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
231         66
232         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
233         13
234         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
235         25
236         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
237         22
238         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
239         75
240         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
241         85
242         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
243         True
244
245         """
246
247         if not node:
248             return None
249
250         if target == node.value:
251             return node
252
253         elif target > node.value:
254             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
255
256         # If target < this node's value, either this node is the
257         # answer or the answer is in this node's left subtree.
258         else:
259             if below := self._find_lowest_node_greater_than_or_equal_to(
260                 target, node.left
261             ):
262                 return below
263             else:
264                 return node
265
266     def _parent_path(
267         self, current: Optional[Node], target: Node
268     ) -> List[Optional[Node]]:
269         """Internal helper"""
270         if current is None:
271             return [None]
272         ret: List[Optional[Node]] = [current]
273         if target.value == current.value:
274             return ret
275         elif target.value < current.value:
276             ret.extend(self._parent_path(current.left, target))
277             return ret
278         else:
279             assert target.value > current.value
280             ret.extend(self._parent_path(current.right, target))
281             return ret
282
283     def parent_path(self, node: Node) -> List[Optional[Node]]:
284         """Get a node's parent path in :math:`O(log_2 n)` time.
285
286         Args:
287             node: the node whose parent path should be returned.
288
289         Returns:
290             a list of nodes representing the path from
291             the tree's root to the given node.
292
293         .. note::
294
295             If the node does not exist in the tree, the last element
296             on the path will be None but the path will indicate the
297             ancestor path of that node were it to be inserted.
298
299         >>> t = BinarySearchTree()
300         >>> t.insert(50)
301         >>> t.insert(75)
302         >>> t.insert(25)
303         >>> t.insert(12)
304         >>> t.insert(33)
305         >>> t.insert(4)
306         >>> t.insert(88)
307         >>> t
308         50
309         ├──25
310         │  ├──12
311         │  │  └──4
312         │  └──33
313         └──75
314            └──88
315
316         >>> n = t[4]
317         >>> for x in t.parent_path(n):
318         ...     print(x.value)
319         50
320         25
321         12
322         4
323
324         >>> del t[4]
325         >>> for x in t.parent_path(n):
326         ...     if x is not None:
327         ...         print(x.value)
328         ...     else:
329         ...         print(x)
330         50
331         25
332         12
333         None
334
335         """
336         return self._parent_path(self.root, node)
337
338     def __delitem__(self, value: Comparable) -> bool:
339         """
340         Delete an item from the tree and preserve the BST property in
341         :math:`O(log_2 n) time`.
342
343         Args:
344             value: the value of the node to be deleted.
345
346         Returns:
347             True if the value was found and its associated node was
348             successfully deleted and False otherwise.
349
350         >>> t = BinarySearchTree()
351         >>> t.insert(50)
352         >>> t.insert(75)
353         >>> t.insert(25)
354         >>> t.insert(66)
355         >>> t.insert(22)
356         >>> t.insert(13)
357         >>> t.insert(85)
358         >>> t
359         50
360         ├──25
361         │  └──22
362         │     └──13
363         └──75
364            ├──66
365            └──85
366
367         >>> for value in t.iterate_inorder():
368         ...     print(value)
369         13
370         22
371         25
372         50
373         66
374         75
375         85
376
377         >>> del t[22]  # Note: bool result is discarded
378
379         >>> for value in t.iterate_inorder():
380         ...     print(value)
381         13
382         25
383         50
384         66
385         75
386         85
387
388         >>> t.__delitem__(13)
389         True
390         >>> for value in t.iterate_inorder():
391         ...     print(value)
392         25
393         50
394         66
395         75
396         85
397
398         >>> t.__delitem__(75)
399         True
400         >>> for value in t.iterate_inorder():
401         ...     print(value)
402         25
403         50
404         66
405         85
406         >>> t
407         50
408         ├──25
409         └──85
410            └──66
411
412         >>> t.__delitem__(85)
413         True
414
415         >>> t.__delitem__(99)
416         False
417
418         """
419         if self.root is not None:
420             ret = self._delete(value, None, self.root)
421             if ret:
422                 self.count -= 1
423                 if self.count == 0:
424                     self.root = None
425             return ret
426         return False
427
428     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
429         """This is called just after deleted was deleted from the tree"""
430         pass
431
432     def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
433         """Delete helper"""
434         if node.value == value:
435
436             # Deleting a leaf node
437             if node.left is None and node.right is None:
438                 if parent is not None:
439                     if parent.left == node:
440                         parent.left = None
441                     else:
442                         assert parent.right == node
443                         parent.right = None
444                 self._on_delete(parent, node)
445                 return True
446
447             # Node only has a right.
448             elif node.left is None:
449                 assert node.right is not None
450                 if parent is not None:
451                     if parent.left == node:
452                         parent.left = node.right
453                     else:
454                         assert parent.right == node
455                         parent.right = node.right
456                 self._on_delete(parent, node)
457                 return True
458
459             # Node only has a left.
460             elif node.right is None:
461                 assert node.left is not None
462                 if parent is not None:
463                     if parent.left == node:
464                         parent.left = node.left
465                     else:
466                         assert parent.right == node
467                         parent.right = node.left
468                 self._on_delete(parent, node)
469                 return True
470
471             # Node has both a left and right; get the successor node
472             # to this one and put it here (deleting the successor's
473             # old node).  Because these operations are happening only
474             # in the subtree underneath of node, I'm still calling
475             # this delete an O(log_2 n) operation in the docs.
476             else:
477                 assert node.left is not None and node.right is not None
478                 successor = self.get_next_node(node)
479                 assert successor is not None
480                 node.value = successor.value
481                 return self._delete(node.value, node, node.right)
482
483         elif value < node.value and node.left is not None:
484             return self._delete(value, node, node.left)
485         elif value > node.value and node.right is not None:
486             return self._delete(value, node, node.right)
487         return False
488
489     def __len__(self):
490         """
491         Returns:
492             The count of items in the tree in :math:`O(1)` time.
493
494         >>> t = BinarySearchTree()
495         >>> len(t)
496         0
497         >>> t.insert(50)
498         >>> len(t)
499         1
500         >>> t.__delitem__(50)
501         True
502         >>> len(t)
503         0
504         >>> t.insert(75)
505         >>> t.insert(25)
506         >>> t.insert(66)
507         >>> t.insert(22)
508         >>> t.insert(13)
509         >>> t.insert(85)
510         >>> len(t)
511         6
512
513         """
514         return self.count
515
516     def __contains__(self, value: Comparable) -> bool:
517         """
518         Returns:
519             True if the item is in the tree; False otherwise.
520         """
521         return self.__getitem__(value) is not None
522
523     def _iterate_preorder(self, node: Node):
524         yield node.value
525         if node.left is not None:
526             yield from self._iterate_preorder(node.left)
527         if node.right is not None:
528             yield from self._iterate_preorder(node.right)
529
530     def _iterate_inorder(self, node: Node):
531         if node.left is not None:
532             yield from self._iterate_inorder(node.left)
533         yield node.value
534         if node.right is not None:
535             yield from self._iterate_inorder(node.right)
536
537     def _iterate_postorder(self, node: Node):
538         if node.left is not None:
539             yield from self._iterate_postorder(node.left)
540         if node.right is not None:
541             yield from self._iterate_postorder(node.right)
542         yield node.value
543
544     def iterate_preorder(self):
545         """
546         Returns:
547             A Generator that yields the tree's items in a
548             preorder traversal sequence.
549
550         >>> t = BinarySearchTree()
551         >>> t.insert(50)
552         >>> t.insert(75)
553         >>> t.insert(25)
554         >>> t.insert(66)
555         >>> t.insert(22)
556         >>> t.insert(13)
557
558         >>> for value in t.iterate_preorder():
559         ...     print(value)
560         50
561         25
562         22
563         13
564         75
565         66
566
567         """
568         if self.root is not None:
569             yield from self._iterate_preorder(self.root)
570
571     def iterate_inorder(self):
572         """
573         Returns:
574             A Generator that yield the tree's items in a preorder
575             traversal sequence.
576
577         >>> t = BinarySearchTree()
578         >>> t.insert(50)
579         >>> t.insert(75)
580         >>> t.insert(25)
581         >>> t.insert(66)
582         >>> t.insert(22)
583         >>> t.insert(13)
584         >>> t.insert(24)
585         >>> t
586         50
587         ├──25
588         │  └──22
589         │     ├──13
590         │     └──24
591         └──75
592            └──66
593
594         >>> for value in t.iterate_inorder():
595         ...     print(value)
596         13
597         22
598         24
599         25
600         50
601         66
602         75
603
604         """
605         if self.root is not None:
606             yield from self._iterate_inorder(self.root)
607
608     def iterate_postorder(self):
609         """
610         Returns:
611             A Generator that yield the tree's items in a preorder
612             traversal sequence.
613
614         >>> t = BinarySearchTree()
615         >>> t.insert(50)
616         >>> t.insert(75)
617         >>> t.insert(25)
618         >>> t.insert(66)
619         >>> t.insert(22)
620         >>> t.insert(13)
621
622         >>> for value in t.iterate_postorder():
623         ...     print(value)
624         13
625         22
626         25
627         66
628         75
629         50
630
631         """
632         if self.root is not None:
633             yield from self._iterate_postorder(self.root)
634
635     def _iterate_leaves(self, node: Node):
636         if node.left is not None:
637             yield from self._iterate_leaves(node.left)
638         if node.right is not None:
639             yield from self._iterate_leaves(node.right)
640         if node.left is None and node.right is None:
641             yield node.value
642
643     def iterate_leaves(self):
644         """
645         Returns:
646             A Generator that yields only the leaf nodes in the
647             tree.
648
649         >>> t = BinarySearchTree()
650         >>> t.insert(50)
651         >>> t.insert(75)
652         >>> t.insert(25)
653         >>> t.insert(66)
654         >>> t.insert(22)
655         >>> t.insert(13)
656
657         >>> for value in t.iterate_leaves():
658         ...     print(value)
659         13
660         66
661
662         """
663         if self.root is not None:
664             yield from self._iterate_leaves(self.root)
665
666     def _iterate_by_depth(self, node: Node, depth: int):
667         if depth == 0:
668             yield node.value
669         else:
670             assert depth > 0
671             if node.left is not None:
672                 yield from self._iterate_by_depth(node.left, depth - 1)
673             if node.right is not None:
674                 yield from self._iterate_by_depth(node.right, depth - 1)
675
676     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
677         """
678         Args:
679             depth: the desired depth
680
681         Returns:
682             A Generator that yields nodes at the prescribed depth in
683             the tree.
684
685         >>> t = BinarySearchTree()
686         >>> t.insert(50)
687         >>> t.insert(75)
688         >>> t.insert(25)
689         >>> t.insert(66)
690         >>> t.insert(22)
691         >>> t.insert(13)
692
693         >>> for value in t.iterate_nodes_by_depth(2):
694         ...     print(value)
695         22
696         66
697
698         >>> for value in t.iterate_nodes_by_depth(3):
699         ...     print(value)
700         13
701
702         """
703         if self.root is not None:
704             yield from self._iterate_by_depth(self.root, depth)
705
706     def get_next_node(self, node: Node) -> Optional[Node]:
707         """
708         Args:
709             node: the node whose next greater successor is desired
710
711         Returns:
712             Given a tree node, returns the next greater node in the tree.
713             If the given node is the greatest node in the tree, returns None.
714
715         >>> t = BinarySearchTree()
716         >>> t.insert(50)
717         >>> t.insert(75)
718         >>> t.insert(25)
719         >>> t.insert(66)
720         >>> t.insert(22)
721         >>> t.insert(13)
722         >>> t.insert(23)
723         >>> t
724         50
725         ├──25
726         │  └──22
727         │     ├──13
728         │     └──23
729         └──75
730            └──66
731
732         >>> n = t[23]
733         >>> t.get_next_node(n).value
734         25
735
736         >>> n = t[50]
737         >>> t.get_next_node(n).value
738         66
739
740         >>> n = t[75]
741         >>> t.get_next_node(n) is None
742         True
743
744         """
745         if node.right is not None:
746             x = node.right
747             while x.left is not None:
748                 x = x.left
749             return x
750
751         path = self.parent_path(node)
752         assert path[-1] is not None
753         assert path[-1] == node
754         path = path[:-1]
755         path.reverse()
756         for ancestor in path:
757             assert ancestor is not None
758             if node != ancestor.right:
759                 return ancestor
760             node = ancestor
761         return None
762
763     def get_nodes_in_range_inclusive(
764         self, lower: Comparable, upper: Comparable
765     ) -> Generator[Node, None, None]:
766         """
767         Args:
768             lower: the lower bound of the desired range.
769             upper: the upper bound of the desired range.
770
771         Returns:
772             Generates a sequence of nodes in the desired range.
773
774         >>> t = BinarySearchTree()
775         >>> t.insert(50)
776         >>> t.insert(75)
777         >>> t.insert(25)
778         >>> t.insert(66)
779         >>> t.insert(22)
780         >>> t.insert(13)
781         >>> t.insert(23)
782         >>> t
783         50
784         ├──25
785         │  └──22
786         │     ├──13
787         │     └──23
788         └──75
789            └──66
790
791         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
792         ...     print(node.value)
793         22
794         23
795         25
796         50
797         66
798         """
799         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
800             lower, self.root
801         )
802         while node:
803             if lower <= node.value <= upper:
804                 yield node
805             node = self.get_next_node(node)
806
807     def _depth(self, node: Node, sofar: int) -> int:
808         depth_left = sofar + 1
809         depth_right = sofar + 1
810         if node.left is not None:
811             depth_left = self._depth(node.left, sofar + 1)
812         if node.right is not None:
813             depth_right = self._depth(node.right, sofar + 1)
814         return max(depth_left, depth_right)
815
816     def depth(self) -> int:
817         """
818         Returns:
819             The max height (depth) of the tree in plies (edge distance
820             from root) in :math:`O(log_2 n)` time.
821
822         >>> t = BinarySearchTree()
823         >>> t.depth()
824         0
825
826         >>> t.insert(50)
827         >>> t.depth()
828         1
829
830         >>> t.insert(65)
831         >>> t.depth()
832         2
833
834         >>> t.insert(33)
835         >>> t.depth()
836         2
837
838         >>> t.insert(2)
839         >>> t.insert(1)
840         >>> t.depth()
841         4
842
843         """
844         if self.root is None:
845             return 0
846         return self._depth(self.root, 0)
847
848     def height(self) -> int:
849         """Returns the height (i.e. max depth) of the tree"""
850         return self.depth()
851
852     def repr_traverse(
853         self,
854         padding: str,
855         pointer: str,
856         node: Optional[Node],
857         has_right_sibling: bool,
858     ) -> str:
859         if node is not None:
860             viz = f"\n{padding}{pointer}{node.value}"
861             if has_right_sibling:
862                 padding += "│  "
863             else:
864                 padding += "   "
865
866             pointer_right = "└──"
867             if node.right is not None:
868                 pointer_left = "├──"
869             else:
870                 pointer_left = "└──"
871
872             viz += self.repr_traverse(
873                 padding, pointer_left, node.left, node.right is not None
874             )
875             viz += self.repr_traverse(padding, pointer_right, node.right, False)
876             return viz
877         return ""
878
879     def __repr__(self):
880         """
881         Returns:
882             An ASCII string representation of the tree.
883
884         >>> t = BinarySearchTree()
885         >>> t.insert(50)
886         >>> t.insert(25)
887         >>> t.insert(75)
888         >>> t.insert(12)
889         >>> t.insert(33)
890         >>> t.insert(88)
891         >>> t.insert(55)
892         >>> t
893         50
894         ├──25
895         │  ├──12
896         │  └──33
897         └──75
898            ├──55
899            └──88
900         """
901         if self.root is None:
902             return ""
903
904         ret = f"{self.root.value}"
905         pointer_right = "└──"
906         if self.root.right is None:
907             pointer_left = "└──"
908         else:
909             pointer_left = "├──"
910
911         ret += self.repr_traverse(
912             "", pointer_left, self.root.left, self.root.left is not None
913         )
914         ret += self.repr_traverse("", pointer_right, self.root.right, False)
915         return ret
916
917
918 if __name__ == "__main__":
919     import doctest
920
921     doctest.testmod()