3a0fae9b9b11fb84b077fd6c91ee50f7703aa813
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
9
10
11 class Comparable(Protocol):
12     @abstractmethod
13     def __lt__(self, other: Any) -> bool:
14         ...
15
16     @abstractmethod
17     def __le__(self, other: Any) -> bool:
18         ...
19
20     @abstractmethod
21     def __eq__(self, other: Any) -> bool:
22         ...
23
24
25 class Node:
26     def __init__(self, value: Comparable) -> None:
27         """A BST node.  Note that value can be anything as long as it
28         is comparable with other instances of itself.  Check out
29         :meth:`functools.total_ordering`
30         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
31
32         Args:
33             value: a reference to the value of the node.  Must be comparable
34             to other values.
35
36         """
37         self.left: Optional[Node] = None
38         self.right: Optional[Node] = None
39         self.value: Comparable = value
40
41
42 class BinarySearchTree(object):
43     def __init__(self):
44         self.root = None
45         self.count = 0
46         self.traverse = None
47
48     def get_root(self) -> Optional[Node]:
49         """
50         Returns:
51             The root of the BST
52         """
53
54         return self.root
55
56     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
57         """This is called immediately _after_ a new node is inserted."""
58         pass
59
60     def insert(self, value: Comparable) -> None:
61         """
62         Insert something into the tree.
63
64         Args:
65             value: the value to be inserted.
66
67         >>> t = BinarySearchTree()
68         >>> t.insert(10)
69         >>> t.insert(20)
70         >>> t.insert(5)
71         >>> len(t)
72         3
73
74         >>> t.get_root().value
75         10
76
77         """
78         if self.root is None:
79             self.root = Node(value)
80             self.count = 1
81             self._on_insert(None, self.root)
82         else:
83             self._insert(value, self.root)
84
85     def _insert(self, value: Comparable, node: Node):
86         """Insertion helper"""
87         if value < node.value:
88             if node.left is not None:
89                 self._insert(value, node.left)
90             else:
91                 node.left = Node(value)
92                 self.count += 1
93                 self._on_insert(node, node.left)
94         else:
95             if node.right is not None:
96                 self._insert(value, node.right)
97             else:
98                 node.right = Node(value)
99                 self.count += 1
100                 self._on_insert(node, node.right)
101
102     def __getitem__(self, value: Comparable) -> Optional[Node]:
103         """
104         Find an item in the tree and return its Node.  Returns
105         None if the item is not in the tree.
106
107         >>> t = BinarySearchTree()
108         >>> t[99]
109
110         >>> t.insert(10)
111         >>> t.insert(20)
112         >>> t.insert(5)
113         >>> t[10].value
114         10
115
116         >>> t[99]
117
118         """
119         if self.root is not None:
120             return self._find_exact(value, self.root)
121         return None
122
123     def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
124         """Recursively traverse the tree looking for a node with the
125         target value.  Return that node if it exists, otherwise return
126         None."""
127
128         if target == node.value:
129             return node
130         elif target < node.value and node.left is not None:
131             return self._find_exact(target, node.left)
132         elif target > node.value and node.right is not None:
133             return self._find_exact(target, node.right)
134         return None
135
136     def _find_lowest_node_less_than_or_equal_to(
137         self, target: Comparable, node: Optional[Node]
138     ) -> Optional[Node]:
139         """Find helper that returns the lowest node that is less
140         than or equal to the target value.  Returns None if target is
141         lower than the lowest node in the tree.
142
143         >>> t = BinarySearchTree()
144         >>> t.insert(50)
145         >>> t.insert(75)
146         >>> t.insert(25)
147         >>> t.insert(66)
148         >>> t.insert(22)
149         >>> t.insert(13)
150         >>> t.insert(85)
151         >>> t
152         50
153         ├──25
154         │  └──22
155         │     └──13
156         └──75
157            ├──66
158            └──85
159
160         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
161         25
162         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
163         50
164         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
165         85
166         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
167         22
168         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
169         13
170         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
171         66
172         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
173         75
174         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
175         True
176
177         """
178
179         if not node:
180             return None
181
182         if target == node.value:
183             return node
184
185         elif target > node.value:
186             if below := self._find_lowest_node_less_than_or_equal_to(
187                 target, node.right
188             ):
189                 return below
190             else:
191                 return node
192
193         else:
194             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
195
196     def _find_lowest_node_greater_than_or_equal_to(
197         self, target: Comparable, node: Optional[Node]
198     ) -> Optional[Node]:
199         """Find helper that returns the lowest node that is greater
200         than or equal to the target value.  Returns None if target is
201         higher than the greatest node in the tree.
202
203         >>> t = BinarySearchTree()
204         >>> t.insert(50)
205         >>> t.insert(75)
206         >>> t.insert(25)
207         >>> t.insert(66)
208         >>> t.insert(22)
209         >>> t.insert(13)
210         >>> t.insert(85)
211         >>> t
212         50
213         ├──25
214         │  └──22
215         │     └──13
216         └──75
217            ├──66
218            └──85
219
220         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
221         50
222         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
223         66
224         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
225         13
226         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
227         25
228         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
229         22
230         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
231         75
232         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
233         85
234         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
235         True
236
237         """
238
239         if not node:
240             return None
241
242         if target == node.value:
243             return node
244
245         elif target > node.value:
246             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
247
248         # If target < this node's value, either this node is the
249         # answer or the answer is in this node's left subtree.
250         else:
251             if below := self._find_lowest_node_greater_than_or_equal_to(
252                 target, node.left
253             ):
254                 return below
255             else:
256                 return node
257
258     def _parent_path(
259         self, current: Optional[Node], target: Node
260     ) -> List[Optional[Node]]:
261         """Internal helper"""
262         if current is None:
263             return [None]
264         ret: List[Optional[Node]] = [current]
265         if target.value == current.value:
266             return ret
267         elif target.value < current.value:
268             ret.extend(self._parent_path(current.left, target))
269             return ret
270         else:
271             assert target.value > current.value
272             ret.extend(self._parent_path(current.right, target))
273             return ret
274
275     def parent_path(self, node: Node) -> List[Optional[Node]]:
276         """Get a node's parent path.
277
278         Args:
279             node: the node to check
280
281         Returns:
282             a list of nodes representing the path from
283             the tree's root to the node.
284
285         .. note::
286
287             If the node does not exist in the tree, the last element
288             on the path will be None but the path will indicate the
289             ancestor path of that node were it to be inserted.
290
291         >>> t = BinarySearchTree()
292         >>> t.insert(50)
293         >>> t.insert(75)
294         >>> t.insert(25)
295         >>> t.insert(12)
296         >>> t.insert(33)
297         >>> t.insert(4)
298         >>> t.insert(88)
299         >>> t
300         50
301         ├──25
302         │  ├──12
303         │  │  └──4
304         │  └──33
305         └──75
306            └──88
307
308         >>> n = t[4]
309         >>> for x in t.parent_path(n):
310         ...     print(x.value)
311         50
312         25
313         12
314         4
315
316         >>> del t[4]
317         >>> for x in t.parent_path(n):
318         ...     if x is not None:
319         ...         print(x.value)
320         ...     else:
321         ...         print(x)
322         50
323         25
324         12
325         None
326
327         """
328         return self._parent_path(self.root, node)
329
330     def __delitem__(self, value: Comparable) -> bool:
331         """
332         Delete an item from the tree and preserve the BST property.
333
334         Args:
335             value: the value of the node to be deleted.
336
337         Returns:
338             True if the value was found and its associated node was
339             successfully deleted and False otherwise.
340
341         >>> t = BinarySearchTree()
342         >>> t.insert(50)
343         >>> t.insert(75)
344         >>> t.insert(25)
345         >>> t.insert(66)
346         >>> t.insert(22)
347         >>> t.insert(13)
348         >>> t.insert(85)
349         >>> t
350         50
351         ├──25
352         │  └──22
353         │     └──13
354         └──75
355            ├──66
356            └──85
357
358         >>> for value in t.iterate_inorder():
359         ...     print(value)
360         13
361         22
362         25
363         50
364         66
365         75
366         85
367
368         >>> del t[22]  # Note: bool result is discarded
369
370         >>> for value in t.iterate_inorder():
371         ...     print(value)
372         13
373         25
374         50
375         66
376         75
377         85
378
379         >>> t.__delitem__(13)
380         True
381         >>> for value in t.iterate_inorder():
382         ...     print(value)
383         25
384         50
385         66
386         75
387         85
388
389         >>> t.__delitem__(75)
390         True
391         >>> for value in t.iterate_inorder():
392         ...     print(value)
393         25
394         50
395         66
396         85
397         >>> t
398         50
399         ├──25
400         └──85
401            └──66
402
403         >>> t.__delitem__(99)
404         False
405
406         """
407         if self.root is not None:
408             ret = self._delete(value, None, self.root)
409             if ret:
410                 self.count -= 1
411                 if self.count == 0:
412                     self.root = None
413             return ret
414         return False
415
416     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
417         """This is called just after deleted was deleted from the tree"""
418         pass
419
420     def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
421         """Delete helper"""
422         if node.value == value:
423
424             # Deleting a leaf node
425             if node.left is None and node.right is None:
426                 if parent is not None:
427                     if parent.left == node:
428                         parent.left = None
429                     else:
430                         assert parent.right == node
431                         parent.right = None
432                 self._on_delete(parent, node)
433                 return True
434
435             # Node only has a right.
436             elif node.left is None:
437                 assert node.right is not None
438                 if parent is not None:
439                     if parent.left == node:
440                         parent.left = node.right
441                     else:
442                         assert parent.right == node
443                         parent.right = node.right
444                 self._on_delete(parent, node)
445                 return True
446
447             # Node only has a left.
448             elif node.right is None:
449                 assert node.left is not None
450                 if parent is not None:
451                     if parent.left == node:
452                         parent.left = node.left
453                     else:
454                         assert parent.right == node
455                         parent.right = node.left
456                 self._on_delete(parent, node)
457                 return True
458
459             # Node has both a left and right.
460             else:
461                 assert node.left is not None and node.right is not None
462                 descendent = node.right
463                 while descendent.left is not None:
464                     descendent = descendent.left
465                 node.value = descendent.value
466                 return self._delete(node.value, node, node.right)
467         elif value < node.value and node.left is not None:
468             return self._delete(value, node, node.left)
469         elif value > node.value and node.right is not None:
470             return self._delete(value, node, node.right)
471         return False
472
473     def __len__(self):
474         """
475         Returns:
476             The count of items in the tree.
477
478         >>> t = BinarySearchTree()
479         >>> len(t)
480         0
481         >>> t.insert(50)
482         >>> len(t)
483         1
484         >>> t.__delitem__(50)
485         True
486         >>> len(t)
487         0
488         >>> t.insert(75)
489         >>> t.insert(25)
490         >>> t.insert(66)
491         >>> t.insert(22)
492         >>> t.insert(13)
493         >>> t.insert(85)
494         >>> len(t)
495         6
496
497         """
498         return self.count
499
500     def __contains__(self, value: Comparable) -> bool:
501         """
502         Returns:
503             True if the item is in the tree; False otherwise.
504         """
505         return self.__getitem__(value) is not None
506
507     def _iterate_preorder(self, node: Node):
508         yield node.value
509         if node.left is not None:
510             yield from self._iterate_preorder(node.left)
511         if node.right is not None:
512             yield from self._iterate_preorder(node.right)
513
514     def _iterate_inorder(self, node: Node):
515         if node.left is not None:
516             yield from self._iterate_inorder(node.left)
517         yield node.value
518         if node.right is not None:
519             yield from self._iterate_inorder(node.right)
520
521     def _iterate_postorder(self, node: Node):
522         if node.left is not None:
523             yield from self._iterate_postorder(node.left)
524         if node.right is not None:
525             yield from self._iterate_postorder(node.right)
526         yield node.value
527
528     def iterate_preorder(self):
529         """
530         Returns:
531             A Generator that yields the tree's items in a
532             preorder traversal sequence.
533
534         >>> t = BinarySearchTree()
535         >>> t.insert(50)
536         >>> t.insert(75)
537         >>> t.insert(25)
538         >>> t.insert(66)
539         >>> t.insert(22)
540         >>> t.insert(13)
541
542         >>> for value in t.iterate_preorder():
543         ...     print(value)
544         50
545         25
546         22
547         13
548         75
549         66
550
551         """
552         if self.root is not None:
553             yield from self._iterate_preorder(self.root)
554
555     def iterate_inorder(self):
556         """
557         Returns:
558             A Generator that yield the tree's items in a preorder
559             traversal sequence.
560
561         >>> t = BinarySearchTree()
562         >>> t.insert(50)
563         >>> t.insert(75)
564         >>> t.insert(25)
565         >>> t.insert(66)
566         >>> t.insert(22)
567         >>> t.insert(13)
568         >>> t.insert(24)
569         >>> t
570         50
571         ├──25
572         │  └──22
573         │     ├──13
574         │     └──24
575         └──75
576            └──66
577
578         >>> for value in t.iterate_inorder():
579         ...     print(value)
580         13
581         22
582         24
583         25
584         50
585         66
586         75
587
588         """
589         if self.root is not None:
590             yield from self._iterate_inorder(self.root)
591
592     def iterate_postorder(self):
593         """
594         Returns:
595             A Generator that yield the tree's items in a preorder
596             traversal sequence.
597
598         >>> t = BinarySearchTree()
599         >>> t.insert(50)
600         >>> t.insert(75)
601         >>> t.insert(25)
602         >>> t.insert(66)
603         >>> t.insert(22)
604         >>> t.insert(13)
605
606         >>> for value in t.iterate_postorder():
607         ...     print(value)
608         13
609         22
610         25
611         66
612         75
613         50
614
615         """
616         if self.root is not None:
617             yield from self._iterate_postorder(self.root)
618
619     def _iterate_leaves(self, node: Node):
620         if node.left is not None:
621             yield from self._iterate_leaves(node.left)
622         if node.right is not None:
623             yield from self._iterate_leaves(node.right)
624         if node.left is None and node.right is None:
625             yield node.value
626
627     def iterate_leaves(self):
628         """
629         Returns:
630             A Gemerator that yielde only the leaf nodes in the
631             tree.
632
633         >>> t = BinarySearchTree()
634         >>> t.insert(50)
635         >>> t.insert(75)
636         >>> t.insert(25)
637         >>> t.insert(66)
638         >>> t.insert(22)
639         >>> t.insert(13)
640
641         >>> for value in t.iterate_leaves():
642         ...     print(value)
643         13
644         66
645
646         """
647         if self.root is not None:
648             yield from self._iterate_leaves(self.root)
649
650     def _iterate_by_depth(self, node: Node, depth: int):
651         if depth == 0:
652             yield node.value
653         else:
654             assert depth > 0
655             if node.left is not None:
656                 yield from self._iterate_by_depth(node.left, depth - 1)
657             if node.right is not None:
658                 yield from self._iterate_by_depth(node.right, depth - 1)
659
660     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
661         """
662         Args:
663             depth: the desired depth
664
665         Returns:
666             A Generator that yields nodes at the prescribed depth in
667             the tree.
668
669         >>> t = BinarySearchTree()
670         >>> t.insert(50)
671         >>> t.insert(75)
672         >>> t.insert(25)
673         >>> t.insert(66)
674         >>> t.insert(22)
675         >>> t.insert(13)
676
677         >>> for value in t.iterate_nodes_by_depth(2):
678         ...     print(value)
679         22
680         66
681
682         >>> for value in t.iterate_nodes_by_depth(3):
683         ...     print(value)
684         13
685
686         """
687         if self.root is not None:
688             yield from self._iterate_by_depth(self.root, depth)
689
690     def get_next_node(self, node: Node) -> Optional[Node]:
691         """
692         Args:
693             node: the node whose next greater successor is desired
694
695         Returns:
696             Given a tree node, returns the next greater node in the tree.
697             If the given node is the greatest node in the tree, returns None.
698
699         >>> t = BinarySearchTree()
700         >>> t.insert(50)
701         >>> t.insert(75)
702         >>> t.insert(25)
703         >>> t.insert(66)
704         >>> t.insert(22)
705         >>> t.insert(13)
706         >>> t.insert(23)
707         >>> t
708         50
709         ├──25
710         │  └──22
711         │     ├──13
712         │     └──23
713         └──75
714            └──66
715
716         >>> n = t[23]
717         >>> t.get_next_node(n).value
718         25
719
720         >>> n = t[50]
721         >>> t.get_next_node(n).value
722         66
723
724         >>> n = t[75]
725         >>> t.get_next_node(n) is None
726         True
727
728         """
729         if node.right is not None:
730             x = node.right
731             while x.left is not None:
732                 x = x.left
733             return x
734
735         path = self.parent_path(node)
736         assert path[-1] is not None
737         assert path[-1] == node
738         path = path[:-1]
739         path.reverse()
740         for ancestor in path:
741             assert ancestor is not None
742             if node != ancestor.right:
743                 return ancestor
744             node = ancestor
745         return None
746
747     def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
748         """
749         >>> t = BinarySearchTree()
750         >>> t.insert(50)
751         >>> t.insert(75)
752         >>> t.insert(25)
753         >>> t.insert(66)
754         >>> t.insert(22)
755         >>> t.insert(13)
756         >>> t.insert(23)
757         >>> t
758         50
759         ├──25
760         │  └──22
761         │     ├──13
762         │     └──23
763         └──75
764            └──66
765
766         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
767         ...     print(node.value)
768         22
769         23
770         25
771         50
772         66
773         """
774         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
775             lower, self.root
776         )
777         while node:
778             if lower <= node.value <= upper:
779                 yield node
780             node = self.get_next_node(node)
781
782     def _depth(self, node: Node, sofar: int) -> int:
783         depth_left = sofar + 1
784         depth_right = sofar + 1
785         if node.left is not None:
786             depth_left = self._depth(node.left, sofar + 1)
787         if node.right is not None:
788             depth_right = self._depth(node.right, sofar + 1)
789         return max(depth_left, depth_right)
790
791     def depth(self) -> int:
792         """
793         Returns:
794             The max height (depth) of the tree in plies (edge distance
795             from root).
796
797         >>> t = BinarySearchTree()
798         >>> t.depth()
799         0
800
801         >>> t.insert(50)
802         >>> t.depth()
803         1
804
805         >>> t.insert(65)
806         >>> t.depth()
807         2
808
809         >>> t.insert(33)
810         >>> t.depth()
811         2
812
813         >>> t.insert(2)
814         >>> t.insert(1)
815         >>> t.depth()
816         4
817
818         """
819         if self.root is None:
820             return 0
821         return self._depth(self.root, 0)
822
823     def height(self) -> int:
824         """Returns the height (i.e. max depth) of the tree"""
825         return self.depth()
826
827     def repr_traverse(
828         self,
829         padding: str,
830         pointer: str,
831         node: Optional[Node],
832         has_right_sibling: bool,
833     ) -> str:
834         if node is not None:
835             viz = f"\n{padding}{pointer}{node.value}"
836             if has_right_sibling:
837                 padding += "│  "
838             else:
839                 padding += "   "
840
841             pointer_right = "└──"
842             if node.right is not None:
843                 pointer_left = "├──"
844             else:
845                 pointer_left = "└──"
846
847             viz += self.repr_traverse(
848                 padding, pointer_left, node.left, node.right is not None
849             )
850             viz += self.repr_traverse(padding, pointer_right, node.right, False)
851             return viz
852         return ""
853
854     def __repr__(self):
855         """
856         Returns:
857             An ASCII string representation of the tree.
858
859         >>> t = BinarySearchTree()
860         >>> t.insert(50)
861         >>> t.insert(25)
862         >>> t.insert(75)
863         >>> t.insert(12)
864         >>> t.insert(33)
865         >>> t.insert(88)
866         >>> t.insert(55)
867         >>> t
868         50
869         ├──25
870         │  ├──12
871         │  └──33
872         └──75
873            ├──55
874            └──88
875         """
876         if self.root is None:
877             return ""
878
879         ret = f"{self.root.value}"
880         pointer_right = "└──"
881         if self.root.right is None:
882             pointer_left = "└──"
883         else:
884             pointer_left = "├──"
885
886         ret += self.repr_traverse(
887             "", pointer_left, self.root.left, self.root.left is not None
888         )
889         ret += self.repr_traverse("", pointer_right, self.root.right, False)
890         return ret
891
892
893 if __name__ == "__main__":
894     import doctest
895
896     doctest.testmod()