Adds a __repr__ to graph.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Generator, List, Optional
8
9 from pyutils.typez.typing import Comparable
10
11
12 class Node:
13     def __init__(self, value: Comparable) -> None:
14         """A BST node.  Just a left and right reference along with a
15         value.  Note that value can be anything as long as it
16         is :class:`Comparable` with other instances of itself.
17
18         Args:
19             value: a reference to the value of the node.  Must be
20                 :class:`Comparable` to other values.
21
22         """
23         self.left: Optional[Node] = None
24         self.right: Optional[Node] = None
25         self.value: Comparable = value
26
27
28 class BinarySearchTree(object):
29     def __init__(self):
30         self.root = None
31         self.count = 0
32         self.traverse = None
33
34     def get_root(self) -> Optional[Node]:
35         """
36         Returns:
37             The root of the BST
38         """
39
40         return self.root
41
42     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
43         """This is called immediately _after_ a new node is inserted."""
44         pass
45
46     def insert(self, value: Comparable) -> None:
47         """
48         Insert something into the tree in :math:`O(log_2 n)` time.
49
50         Args:
51             value: the value to be inserted.
52
53         >>> t = BinarySearchTree()
54         >>> t.insert(10)
55         >>> t.insert(20)
56         >>> t.insert(5)
57         >>> len(t)
58         3
59
60         >>> t.get_root().value
61         10
62
63         """
64         if self.root is None:
65             self.root = Node(value)
66             self.count = 1
67             self._on_insert(None, self.root)
68         else:
69             self._insert(value, self.root)
70
71     def _insert(self, value: Comparable, node: Node):
72         """Insertion helper"""
73         if value < node.value:
74             if node.left is not None:
75                 self._insert(value, node.left)
76             else:
77                 node.left = Node(value)
78                 self.count += 1
79                 self._on_insert(node, node.left)
80         else:
81             if node.right is not None:
82                 self._insert(value, node.right)
83             else:
84                 node.right = Node(value)
85                 self.count += 1
86                 self._on_insert(node, node.right)
87
88     def __getitem__(self, value: Comparable) -> Optional[Node]:
89         """
90         Find an item in the tree and return its Node in
91         :math:`O(log_2 n)` time.  Returns None if the item is not in
92         the tree.
93
94         >>> t = BinarySearchTree()
95         >>> t[99]
96
97         >>> t.insert(10)
98         >>> t.insert(20)
99         >>> t.insert(5)
100         >>> t[10].value
101         10
102
103         >>> t[99]
104
105         """
106         if self.root is not None:
107             return self._find_exact(value, self.root)
108         return None
109
110     def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
111         """Recursively traverse the tree looking for a node with the
112         target value.  Return that node if it exists, otherwise return
113         None."""
114
115         if target == node.value:
116             return node
117         elif target < node.value and node.left is not None:
118             return self._find_exact(target, node.left)
119         elif target > node.value and node.right is not None:
120             return self._find_exact(target, node.right)
121         return None
122
123     def _find_lowest_node_less_than_or_equal_to(
124         self, target: Comparable, node: Optional[Node]
125     ) -> Optional[Node]:
126         """Find helper that returns the lowest node that is less
127         than or equal to the target value.  Returns None if target is
128         lower than the lowest node in the tree.
129
130         >>> t = BinarySearchTree()
131         >>> t.insert(50)
132         >>> t.insert(75)
133         >>> t.insert(25)
134         >>> t.insert(66)
135         >>> t.insert(22)
136         >>> t.insert(13)
137         >>> t.insert(85)
138         >>> t
139         50
140         ├──25
141         │  └──22
142         │     └──13
143         └──75
144            ├──66
145            └──85
146
147         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
148         25
149         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
150         50
151         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
152         85
153         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
154         22
155         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
156         13
157         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
158         66
159         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
160         75
161         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
162         True
163
164         """
165
166         if not node:
167             return None
168
169         if target == node.value:
170             return node
171
172         elif target > node.value:
173             if below := self._find_lowest_node_less_than_or_equal_to(
174                 target, node.right
175             ):
176                 return below
177             else:
178                 return node
179
180         else:
181             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
182
183     def _find_lowest_node_greater_than_or_equal_to(
184         self, target: Comparable, node: Optional[Node]
185     ) -> Optional[Node]:
186         """Find helper that returns the lowest node that is greater
187         than or equal to the target value.  Returns None if target is
188         higher than the greatest node in the tree.
189
190         >>> t = BinarySearchTree()
191         >>> t.insert(50)
192         >>> t.insert(75)
193         >>> t.insert(25)
194         >>> t.insert(66)
195         >>> t.insert(22)
196         >>> t.insert(13)
197         >>> t.insert(85)
198         >>> t
199         50
200         ├──25
201         │  └──22
202         │     └──13
203         └──75
204            ├──66
205            └──85
206
207         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
208         50
209         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
210         66
211         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
212         13
213         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
214         25
215         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
216         22
217         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
218         75
219         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
220         85
221         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
222         True
223
224         """
225
226         if not node:
227             return None
228
229         if target == node.value:
230             return node
231
232         elif target > node.value:
233             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
234
235         # If target < this node's value, either this node is the
236         # answer or the answer is in this node's left subtree.
237         else:
238             if below := self._find_lowest_node_greater_than_or_equal_to(
239                 target, node.left
240             ):
241                 return below
242             else:
243                 return node
244
245     def _parent_path(
246         self, current: Optional[Node], target: Node
247     ) -> List[Optional[Node]]:
248         """Internal helper"""
249         if current is None:
250             return [None]
251         ret: List[Optional[Node]] = [current]
252         if target.value == current.value:
253             return ret
254         elif target.value < current.value:
255             ret.extend(self._parent_path(current.left, target))
256             return ret
257         else:
258             assert target.value > current.value
259             ret.extend(self._parent_path(current.right, target))
260             return ret
261
262     def parent_path(self, node: Node) -> List[Optional[Node]]:
263         """Get a node's parent path in :math:`O(log_2 n)` time.
264
265         Args:
266             node: the node whose parent path should be returned.
267
268         Returns:
269             a list of nodes representing the path from
270             the tree's root to the given node.
271
272         .. note::
273
274             If the node does not exist in the tree, the last element
275             on the path will be None but the path will indicate the
276             ancestor path of that node were it to be inserted.
277
278         >>> t = BinarySearchTree()
279         >>> t.insert(50)
280         >>> t.insert(75)
281         >>> t.insert(25)
282         >>> t.insert(12)
283         >>> t.insert(33)
284         >>> t.insert(4)
285         >>> t.insert(88)
286         >>> t
287         50
288         ├──25
289         │  ├──12
290         │  │  └──4
291         │  └──33
292         └──75
293            └──88
294
295         >>> n = t[4]
296         >>> for x in t.parent_path(n):
297         ...     print(x.value)
298         50
299         25
300         12
301         4
302
303         >>> del t[4]
304         >>> for x in t.parent_path(n):
305         ...     if x is not None:
306         ...         print(x.value)
307         ...     else:
308         ...         print(x)
309         50
310         25
311         12
312         None
313
314         """
315         return self._parent_path(self.root, node)
316
317     def __delitem__(self, value: Comparable) -> bool:
318         """
319         Delete an item from the tree and preserve the BST property in
320         :math:`O(log_2 n) time`.
321
322         Args:
323             value: the value of the node to be deleted.
324
325         Returns:
326             True if the value was found and its associated node was
327             successfully deleted and False otherwise.
328
329         >>> t = BinarySearchTree()
330         >>> t.insert(50)
331         >>> t.insert(75)
332         >>> t.insert(25)
333         >>> t.insert(66)
334         >>> t.insert(22)
335         >>> t.insert(13)
336         >>> t.insert(85)
337         >>> t
338         50
339         ├──25
340         │  └──22
341         │     └──13
342         └──75
343            ├──66
344            └──85
345
346         >>> for value in t.iterate_inorder():
347         ...     print(value)
348         13
349         22
350         25
351         50
352         66
353         75
354         85
355
356         >>> del t[22]  # Note: bool result is discarded
357
358         >>> for value in t.iterate_inorder():
359         ...     print(value)
360         13
361         25
362         50
363         66
364         75
365         85
366
367         >>> t.__delitem__(13)
368         True
369         >>> for value in t.iterate_inorder():
370         ...     print(value)
371         25
372         50
373         66
374         75
375         85
376
377         >>> t.__delitem__(75)
378         True
379         >>> for value in t.iterate_inorder():
380         ...     print(value)
381         25
382         50
383         66
384         85
385         >>> t
386         50
387         ├──25
388         └──85
389            └──66
390
391         >>> t.__delitem__(85)
392         True
393
394         >>> t.__delitem__(99)
395         False
396
397         """
398         if self.root is not None:
399             ret = self._delete(value, None, self.root)
400             if ret:
401                 self.count -= 1
402                 if self.count == 0:
403                     self.root = None
404             return ret
405         return False
406
407     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
408         """This is called just after deleted was deleted from the tree"""
409         pass
410
411     def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
412         """Delete helper"""
413         if node.value == value:
414
415             # Deleting a leaf node
416             if node.left is None and node.right is None:
417                 if parent is not None:
418                     if parent.left == node:
419                         parent.left = None
420                     else:
421                         assert parent.right == node
422                         parent.right = None
423                 self._on_delete(parent, node)
424                 return True
425
426             # Node only has a right.
427             elif node.left is None:
428                 assert node.right is not None
429                 if parent is not None:
430                     if parent.left == node:
431                         parent.left = node.right
432                     else:
433                         assert parent.right == node
434                         parent.right = node.right
435                 self._on_delete(parent, node)
436                 return True
437
438             # Node only has a left.
439             elif node.right is None:
440                 assert node.left is not None
441                 if parent is not None:
442                     if parent.left == node:
443                         parent.left = node.left
444                     else:
445                         assert parent.right == node
446                         parent.right = node.left
447                 self._on_delete(parent, node)
448                 return True
449
450             # Node has both a left and right; get the successor node
451             # to this one and put it here (deleting the successor's
452             # old node).  Because these operations are happening only
453             # in the subtree underneath of node, I'm still calling
454             # this delete an O(log_2 n) operation in the docs.
455             else:
456                 assert node.left is not None and node.right is not None
457                 successor = self.get_next_node(node)
458                 assert successor is not None
459                 node.value = successor.value
460                 return self._delete(node.value, node, node.right)
461
462         elif value < node.value and node.left is not None:
463             return self._delete(value, node, node.left)
464         elif value > node.value and node.right is not None:
465             return self._delete(value, node, node.right)
466         return False
467
468     def __len__(self):
469         """
470         Returns:
471             The count of items in the tree in :math:`O(1)` time.
472
473         >>> t = BinarySearchTree()
474         >>> len(t)
475         0
476         >>> t.insert(50)
477         >>> len(t)
478         1
479         >>> t.__delitem__(50)
480         True
481         >>> len(t)
482         0
483         >>> t.insert(75)
484         >>> t.insert(25)
485         >>> t.insert(66)
486         >>> t.insert(22)
487         >>> t.insert(13)
488         >>> t.insert(85)
489         >>> len(t)
490         6
491
492         """
493         return self.count
494
495     def __contains__(self, value: Comparable) -> bool:
496         """
497         Returns:
498             True if the item is in the tree; False otherwise.
499         """
500         return self.__getitem__(value) is not None
501
502     def _iterate_preorder(self, node: Node):
503         yield node.value
504         if node.left is not None:
505             yield from self._iterate_preorder(node.left)
506         if node.right is not None:
507             yield from self._iterate_preorder(node.right)
508
509     def _iterate_inorder(self, node: Node):
510         if node.left is not None:
511             yield from self._iterate_inorder(node.left)
512         yield node.value
513         if node.right is not None:
514             yield from self._iterate_inorder(node.right)
515
516     def _iterate_postorder(self, node: Node):
517         if node.left is not None:
518             yield from self._iterate_postorder(node.left)
519         if node.right is not None:
520             yield from self._iterate_postorder(node.right)
521         yield node.value
522
523     def iterate_preorder(self):
524         """
525         Returns:
526             A Generator that yields the tree's items in a
527             preorder traversal sequence.
528
529         >>> t = BinarySearchTree()
530         >>> t.insert(50)
531         >>> t.insert(75)
532         >>> t.insert(25)
533         >>> t.insert(66)
534         >>> t.insert(22)
535         >>> t.insert(13)
536
537         >>> for value in t.iterate_preorder():
538         ...     print(value)
539         50
540         25
541         22
542         13
543         75
544         66
545
546         """
547         if self.root is not None:
548             yield from self._iterate_preorder(self.root)
549
550     def iterate_inorder(self):
551         """
552         Returns:
553             A Generator that yield the tree's items in a preorder
554             traversal sequence.
555
556         >>> t = BinarySearchTree()
557         >>> t.insert(50)
558         >>> t.insert(75)
559         >>> t.insert(25)
560         >>> t.insert(66)
561         >>> t.insert(22)
562         >>> t.insert(13)
563         >>> t.insert(24)
564         >>> t
565         50
566         ├──25
567         │  └──22
568         │     ├──13
569         │     └──24
570         └──75
571            └──66
572
573         >>> for value in t.iterate_inorder():
574         ...     print(value)
575         13
576         22
577         24
578         25
579         50
580         66
581         75
582
583         """
584         if self.root is not None:
585             yield from self._iterate_inorder(self.root)
586
587     def iterate_postorder(self):
588         """
589         Returns:
590             A Generator that yield the tree's items in a preorder
591             traversal sequence.
592
593         >>> t = BinarySearchTree()
594         >>> t.insert(50)
595         >>> t.insert(75)
596         >>> t.insert(25)
597         >>> t.insert(66)
598         >>> t.insert(22)
599         >>> t.insert(13)
600
601         >>> for value in t.iterate_postorder():
602         ...     print(value)
603         13
604         22
605         25
606         66
607         75
608         50
609
610         """
611         if self.root is not None:
612             yield from self._iterate_postorder(self.root)
613
614     def _iterate_leaves(self, node: Node):
615         if node.left is not None:
616             yield from self._iterate_leaves(node.left)
617         if node.right is not None:
618             yield from self._iterate_leaves(node.right)
619         if node.left is None and node.right is None:
620             yield node.value
621
622     def iterate_leaves(self):
623         """
624         Returns:
625             A Generator that yields only the leaf nodes in the
626             tree.
627
628         >>> t = BinarySearchTree()
629         >>> t.insert(50)
630         >>> t.insert(75)
631         >>> t.insert(25)
632         >>> t.insert(66)
633         >>> t.insert(22)
634         >>> t.insert(13)
635
636         >>> for value in t.iterate_leaves():
637         ...     print(value)
638         13
639         66
640
641         """
642         if self.root is not None:
643             yield from self._iterate_leaves(self.root)
644
645     def _iterate_by_depth(self, node: Node, depth: int):
646         if depth == 0:
647             yield node.value
648         else:
649             assert depth > 0
650             if node.left is not None:
651                 yield from self._iterate_by_depth(node.left, depth - 1)
652             if node.right is not None:
653                 yield from self._iterate_by_depth(node.right, depth - 1)
654
655     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
656         """
657         Args:
658             depth: the desired depth
659
660         Returns:
661             A Generator that yields nodes at the prescribed depth in
662             the tree.
663
664         >>> t = BinarySearchTree()
665         >>> t.insert(50)
666         >>> t.insert(75)
667         >>> t.insert(25)
668         >>> t.insert(66)
669         >>> t.insert(22)
670         >>> t.insert(13)
671
672         >>> for value in t.iterate_nodes_by_depth(2):
673         ...     print(value)
674         22
675         66
676
677         >>> for value in t.iterate_nodes_by_depth(3):
678         ...     print(value)
679         13
680
681         """
682         if self.root is not None:
683             yield from self._iterate_by_depth(self.root, depth)
684
685     def get_next_node(self, node: Node) -> Optional[Node]:
686         """
687         Args:
688             node: the node whose next greater successor is desired
689
690         Returns:
691             Given a tree node, returns the next greater node in the tree.
692             If the given node is the greatest node in the tree, returns None.
693
694         >>> t = BinarySearchTree()
695         >>> t.insert(50)
696         >>> t.insert(75)
697         >>> t.insert(25)
698         >>> t.insert(66)
699         >>> t.insert(22)
700         >>> t.insert(13)
701         >>> t.insert(23)
702         >>> t
703         50
704         ├──25
705         │  └──22
706         │     ├──13
707         │     └──23
708         └──75
709            └──66
710
711         >>> n = t[23]
712         >>> t.get_next_node(n).value
713         25
714
715         >>> n = t[50]
716         >>> t.get_next_node(n).value
717         66
718
719         >>> n = t[75]
720         >>> t.get_next_node(n) is None
721         True
722
723         """
724         if node.right is not None:
725             x = node.right
726             while x.left is not None:
727                 x = x.left
728             return x
729
730         path = self.parent_path(node)
731         assert path[-1] is not None
732         assert path[-1] == node
733         path = path[:-1]
734         path.reverse()
735         for ancestor in path:
736             assert ancestor is not None
737             if node != ancestor.right:
738                 return ancestor
739             node = ancestor
740         return None
741
742     def get_nodes_in_range_inclusive(
743         self, lower: Comparable, upper: Comparable
744     ) -> Generator[Node, None, None]:
745         """
746         Args:
747             lower: the lower bound of the desired range.
748             upper: the upper bound of the desired range.
749
750         Returns:
751             Generates a sequence of nodes in the desired range.
752
753         >>> t = BinarySearchTree()
754         >>> t.insert(50)
755         >>> t.insert(75)
756         >>> t.insert(25)
757         >>> t.insert(66)
758         >>> t.insert(22)
759         >>> t.insert(13)
760         >>> t.insert(23)
761         >>> t
762         50
763         ├──25
764         │  └──22
765         │     ├──13
766         │     └──23
767         └──75
768            └──66
769
770         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
771         ...     print(node.value)
772         22
773         23
774         25
775         50
776         66
777         """
778         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
779             lower, self.root
780         )
781         while node:
782             if lower <= node.value <= upper:
783                 yield node
784             node = self.get_next_node(node)
785
786     def _depth(self, node: Node, sofar: int) -> int:
787         depth_left = sofar + 1
788         depth_right = sofar + 1
789         if node.left is not None:
790             depth_left = self._depth(node.left, sofar + 1)
791         if node.right is not None:
792             depth_right = self._depth(node.right, sofar + 1)
793         return max(depth_left, depth_right)
794
795     def depth(self) -> int:
796         """
797         Returns:
798             The max height (depth) of the tree in plies (edge distance
799             from root) in :math:`O(log_2 n)` time.
800
801         >>> t = BinarySearchTree()
802         >>> t.depth()
803         0
804
805         >>> t.insert(50)
806         >>> t.depth()
807         1
808
809         >>> t.insert(65)
810         >>> t.depth()
811         2
812
813         >>> t.insert(33)
814         >>> t.depth()
815         2
816
817         >>> t.insert(2)
818         >>> t.insert(1)
819         >>> t.depth()
820         4
821
822         """
823         if self.root is None:
824             return 0
825         return self._depth(self.root, 0)
826
827     def height(self) -> int:
828         """Returns the height (i.e. max depth) of the tree"""
829         return self.depth()
830
831     def repr_traverse(
832         self,
833         padding: str,
834         pointer: str,
835         node: Optional[Node],
836         has_right_sibling: bool,
837     ) -> str:
838         if node is not None:
839             viz = f"\n{padding}{pointer}{node.value}"
840             if has_right_sibling:
841                 padding += "│  "
842             else:
843                 padding += "   "
844
845             pointer_right = "└──"
846             if node.right is not None:
847                 pointer_left = "├──"
848             else:
849                 pointer_left = "└──"
850
851             viz += self.repr_traverse(
852                 padding, pointer_left, node.left, node.right is not None
853             )
854             viz += self.repr_traverse(padding, pointer_right, node.right, False)
855             return viz
856         return ""
857
858     def __repr__(self):
859         """
860         Returns:
861             An ASCII string representation of the tree.
862
863         >>> t = BinarySearchTree()
864         >>> t.insert(50)
865         >>> t.insert(25)
866         >>> t.insert(75)
867         >>> t.insert(12)
868         >>> t.insert(33)
869         >>> t.insert(88)
870         >>> t.insert(55)
871         >>> t
872         50
873         ├──25
874         │  ├──12
875         │  └──33
876         └──75
877            ├──55
878            └──88
879         """
880         if self.root is None:
881             return ""
882
883         ret = f"{self.root.value}"
884         pointer_right = "└──"
885         if self.root.right is None:
886             pointer_left = "└──"
887         else:
888             pointer_left = "├──"
889
890         ret += self.repr_traverse(
891             "", pointer_left, self.root.left, self.root.left is not None
892         )
893         ret += self.repr_traverse("", pointer_right, self.root.right, False)
894         return ret
895
896
897 if __name__ == "__main__":
898     import doctest
899
900     doctest.testmod()