Improve type hints in bst.py.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
9
10
11 class Comparable(metaclass=ABCMeta):
12     @abstractmethod
13     def __lt__(self, other: Any) -> bool:
14         pass
15
16     @abstractmethod
17     def __le__(self, other: Any) -> bool:
18         pass
19
20     @abstractmethod
21     def __eq__(self, other: Any) -> bool:
22         pass
23
24
25 ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
26
27
28 class Node(object):
29     def __init__(self, value: ComparableNodeValue) -> None:
30         """A BST node.  Note that value can be anything as long as it
31         is comparable with other instances of itself.  Check out
32         :meth:`functools.total_ordering`
33         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
34
35         Args:
36             value: a reference to the value of the node.
37
38         """
39         self.left: Optional[Node] = None
40         self.right: Optional[Node] = None
41         self.value: ComparableNodeValue = value
42
43
44 class BinarySearchTree(object):
45     def __init__(self):
46         self.root = None
47         self.count = 0
48         self.traverse = None
49
50     def get_root(self) -> Optional[Node]:
51         """
52         Returns:
53             The root of the BST
54         """
55
56         return self.root
57
58     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
59         """This is called immediately _after_ a new node is inserted."""
60         pass
61
62     def insert(self, value: ComparableNodeValue) -> None:
63         """
64         Insert something into the tree.
65
66         Args:
67             value: the value to be inserted.
68
69         >>> t = BinarySearchTree()
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> len(t)
74         3
75
76         >>> t.get_root().value
77         10
78
79         """
80         if self.root is None:
81             self.root = Node(value)
82             self.count = 1
83             self._on_insert(None, self.root)
84         else:
85             self._insert(value, self.root)
86
87     def _insert(self, value: ComparableNodeValue, node: Node):
88         """Insertion helper"""
89         if value < node.value:
90             if node.left is not None:
91                 self._insert(value, node.left)
92             else:
93                 node.left = Node(value)
94                 self.count += 1
95                 self._on_insert(node, node.left)
96         else:
97             if node.right is not None:
98                 self._insert(value, node.right)
99             else:
100                 node.right = Node(value)
101                 self.count += 1
102                 self._on_insert(node, node.right)
103
104     def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
105         """
106         Find an item in the tree and return its Node.  Returns
107         None if the item is not in the tree.
108
109         >>> t = BinarySearchTree()
110         >>> t[99]
111
112         >>> t.insert(10)
113         >>> t.insert(20)
114         >>> t.insert(5)
115         >>> t[10].value
116         10
117
118         >>> t[99]
119
120         """
121         if self.root is not None:
122             return self._find(value, self.root)
123         return None
124
125     def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
126         """Find helper"""
127         if value == node.value:
128             return node
129         elif value < node.value and node.left is not None:
130             return self._find(value, node.left)
131         elif value > node.value and node.right is not None:
132             return self._find(value, node.right)
133         return None
134
135     def _find_lowest_value_greater_than_or_equal_to(
136         self, target: ComparableNodeValue, node: Optional[Node]
137     ) -> Optional[Node]:
138         """Find helper that returns the lowest node that is greater
139         than or equal to the target value.  Returns None if target is
140         greater than the highest node in the tree.
141
142         >>> t = BinarySearchTree()
143         >>> t.insert(50)
144         >>> t.insert(75)
145         >>> t.insert(25)
146         >>> t.insert(66)
147         >>> t.insert(22)
148         >>> t.insert(13)
149         >>> t.insert(85)
150         >>> t
151         50
152         ├──25
153         │  └──22
154         │     └──13
155         └──75
156            ├──66
157            └──85
158
159         >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
160         50
161         >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
162         66
163         >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
164         13
165         >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
166         25
167         >>> t._find_lowest_value_greater_than_or_equal_to(20, t.root).value
168         22
169         >>> t._find_lowest_value_greater_than_or_equal_to(72, t.root).value
170         75
171         >>> t._find_lowest_value_greater_than_or_equal_to(78, t.root).value
172         85
173         >>> t._find_lowest_value_greater_than_or_equal_to(95, t.root) is None
174         True
175
176         """
177
178         if not node:
179             return None
180
181         if target == node.value:
182             return node
183         elif target >= node.value:
184             return self._find_lowest_value_greater_than_or_equal_to(target, node.right)
185         else:
186             assert target < node.value
187             if below := self._find_lowest_value_greater_than_or_equal_to(
188                 target, node.left
189             ):
190                 return below
191             else:
192                 return node
193
194     def _parent_path(
195         self, current: Optional[Node], target: Node
196     ) -> List[Optional[Node]]:
197         """Internal helper"""
198         if current is None:
199             return [None]
200         ret: List[Optional[Node]] = [current]
201         if target.value == current.value:
202             return ret
203         elif target.value < current.value:
204             ret.extend(self._parent_path(current.left, target))
205             return ret
206         else:
207             assert target.value > current.value
208             ret.extend(self._parent_path(current.right, target))
209             return ret
210
211     def parent_path(self, node: Node) -> List[Optional[Node]]:
212         """Get a node's parent path.
213
214         Args:
215             node: the node to check
216
217         Returns:
218             a list of nodes representing the path from
219             the tree's root to the node.
220
221         .. note::
222
223             If the node does not exist in the tree, the last element
224             on the path will be None but the path will indicate the
225             ancestor path of that node were it to be inserted.
226
227         >>> t = BinarySearchTree()
228         >>> t.insert(50)
229         >>> t.insert(75)
230         >>> t.insert(25)
231         >>> t.insert(12)
232         >>> t.insert(33)
233         >>> t.insert(4)
234         >>> t.insert(88)
235         >>> t
236         50
237         ├──25
238         │  ├──12
239         │  │  └──4
240         │  └──33
241         └──75
242            └──88
243
244         >>> n = t[4]
245         >>> for x in t.parent_path(n):
246         ...     print(x.value)
247         50
248         25
249         12
250         4
251
252         >>> del t[4]
253         >>> for x in t.parent_path(n):
254         ...     if x is not None:
255         ...         print(x.value)
256         ...     else:
257         ...         print(x)
258         50
259         25
260         12
261         None
262
263         """
264         return self._parent_path(self.root, node)
265
266     def __delitem__(self, value: ComparableNodeValue) -> bool:
267         """
268         Delete an item from the tree and preserve the BST property.
269
270         Args:
271             value: the value of the node to be deleted.
272
273         Returns:
274             True if the value was found and its associated node was
275             successfully deleted and False otherwise.
276
277         >>> t = BinarySearchTree()
278         >>> t.insert(50)
279         >>> t.insert(75)
280         >>> t.insert(25)
281         >>> t.insert(66)
282         >>> t.insert(22)
283         >>> t.insert(13)
284         >>> t.insert(85)
285         >>> t
286         50
287         ├──25
288         │  └──22
289         │     └──13
290         └──75
291            ├──66
292            └──85
293
294         >>> for value in t.iterate_inorder():
295         ...     print(value)
296         13
297         22
298         25
299         50
300         66
301         75
302         85
303
304         >>> del t[22]  # Note: bool result is discarded
305
306         >>> for value in t.iterate_inorder():
307         ...     print(value)
308         13
309         25
310         50
311         66
312         75
313         85
314
315         >>> t.__delitem__(13)
316         True
317         >>> for value in t.iterate_inorder():
318         ...     print(value)
319         25
320         50
321         66
322         75
323         85
324
325         >>> t.__delitem__(75)
326         True
327         >>> for value in t.iterate_inorder():
328         ...     print(value)
329         25
330         50
331         66
332         85
333         >>> t
334         50
335         ├──25
336         └──85
337            └──66
338
339         >>> t.__delitem__(99)
340         False
341
342         """
343         if self.root is not None:
344             ret = self._delete(value, None, self.root)
345             if ret:
346                 self.count -= 1
347                 if self.count == 0:
348                     self.root = None
349             return ret
350         return False
351
352     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
353         """This is called just after deleted was deleted from the tree"""
354         pass
355
356     def _delete(
357         self, value: ComparableNodeValue, parent: Optional[Node], node: Node
358     ) -> bool:
359         """Delete helper"""
360         if node.value == value:
361
362             # Deleting a leaf node
363             if node.left is None and node.right is None:
364                 if parent is not None:
365                     if parent.left == node:
366                         parent.left = None
367                     else:
368                         assert parent.right == node
369                         parent.right = None
370                 self._on_delete(parent, node)
371                 return True
372
373             # Node only has a right.
374             elif node.left is None:
375                 assert node.right is not None
376                 if parent is not None:
377                     if parent.left == node:
378                         parent.left = node.right
379                     else:
380                         assert parent.right == node
381                         parent.right = node.right
382                 self._on_delete(parent, node)
383                 return True
384
385             # Node only has a left.
386             elif node.right is None:
387                 assert node.left is not None
388                 if parent is not None:
389                     if parent.left == node:
390                         parent.left = node.left
391                     else:
392                         assert parent.right == node
393                         parent.right = node.left
394                 self._on_delete(parent, node)
395                 return True
396
397             # Node has both a left and right.
398             else:
399                 assert node.left is not None and node.right is not None
400                 descendent = node.right
401                 while descendent.left is not None:
402                     descendent = descendent.left
403                 node.value = descendent.value
404                 return self._delete(node.value, node, node.right)
405         elif value < node.value and node.left is not None:
406             return self._delete(value, node, node.left)
407         elif value > node.value and node.right is not None:
408             return self._delete(value, node, node.right)
409         return False
410
411     def __len__(self):
412         """
413         Returns:
414             The count of items in the tree.
415
416         >>> t = BinarySearchTree()
417         >>> len(t)
418         0
419         >>> t.insert(50)
420         >>> len(t)
421         1
422         >>> t.__delitem__(50)
423         True
424         >>> len(t)
425         0
426         >>> t.insert(75)
427         >>> t.insert(25)
428         >>> t.insert(66)
429         >>> t.insert(22)
430         >>> t.insert(13)
431         >>> t.insert(85)
432         >>> len(t)
433         6
434
435         """
436         return self.count
437
438     def __contains__(self, value: ComparableNodeValue) -> bool:
439         """
440         Returns:
441             True if the item is in the tree; False otherwise.
442         """
443         return self.__getitem__(value) is not None
444
445     def _iterate_preorder(self, node: Node):
446         yield node.value
447         if node.left is not None:
448             yield from self._iterate_preorder(node.left)
449         if node.right is not None:
450             yield from self._iterate_preorder(node.right)
451
452     def _iterate_inorder(self, node: Node):
453         if node.left is not None:
454             yield from self._iterate_inorder(node.left)
455         yield node.value
456         if node.right is not None:
457             yield from self._iterate_inorder(node.right)
458
459     def _iterate_postorder(self, node: Node):
460         if node.left is not None:
461             yield from self._iterate_postorder(node.left)
462         if node.right is not None:
463             yield from self._iterate_postorder(node.right)
464         yield node.value
465
466     def iterate_preorder(self):
467         """
468         Returns:
469             A Generator that yields the tree's items in a
470             preorder traversal sequence.
471
472         >>> t = BinarySearchTree()
473         >>> t.insert(50)
474         >>> t.insert(75)
475         >>> t.insert(25)
476         >>> t.insert(66)
477         >>> t.insert(22)
478         >>> t.insert(13)
479
480         >>> for value in t.iterate_preorder():
481         ...     print(value)
482         50
483         25
484         22
485         13
486         75
487         66
488
489         """
490         if self.root is not None:
491             yield from self._iterate_preorder(self.root)
492
493     def iterate_inorder(self):
494         """
495         Returns:
496             A Generator that yield the tree's items in a preorder
497             traversal sequence.
498
499         >>> t = BinarySearchTree()
500         >>> t.insert(50)
501         >>> t.insert(75)
502         >>> t.insert(25)
503         >>> t.insert(66)
504         >>> t.insert(22)
505         >>> t.insert(13)
506         >>> t.insert(24)
507         >>> t
508         50
509         ├──25
510         │  └──22
511         │     ├──13
512         │     └──24
513         └──75
514            └──66
515
516         >>> for value in t.iterate_inorder():
517         ...     print(value)
518         13
519         22
520         24
521         25
522         50
523         66
524         75
525
526         """
527         if self.root is not None:
528             yield from self._iterate_inorder(self.root)
529
530     def iterate_postorder(self):
531         """
532         Returns:
533             A Generator that yield the tree's items in a preorder
534             traversal sequence.
535
536         >>> t = BinarySearchTree()
537         >>> t.insert(50)
538         >>> t.insert(75)
539         >>> t.insert(25)
540         >>> t.insert(66)
541         >>> t.insert(22)
542         >>> t.insert(13)
543
544         >>> for value in t.iterate_postorder():
545         ...     print(value)
546         13
547         22
548         25
549         66
550         75
551         50
552
553         """
554         if self.root is not None:
555             yield from self._iterate_postorder(self.root)
556
557     def _iterate_leaves(self, node: Node):
558         if node.left is not None:
559             yield from self._iterate_leaves(node.left)
560         if node.right is not None:
561             yield from self._iterate_leaves(node.right)
562         if node.left is None and node.right is None:
563             yield node.value
564
565     def iterate_leaves(self):
566         """
567         Returns:
568             A Gemerator that yielde only the leaf nodes in the
569             tree.
570
571         >>> t = BinarySearchTree()
572         >>> t.insert(50)
573         >>> t.insert(75)
574         >>> t.insert(25)
575         >>> t.insert(66)
576         >>> t.insert(22)
577         >>> t.insert(13)
578
579         >>> for value in t.iterate_leaves():
580         ...     print(value)
581         13
582         66
583
584         """
585         if self.root is not None:
586             yield from self._iterate_leaves(self.root)
587
588     def _iterate_by_depth(self, node: Node, depth: int):
589         if depth == 0:
590             yield node.value
591         else:
592             assert depth > 0
593             if node.left is not None:
594                 yield from self._iterate_by_depth(node.left, depth - 1)
595             if node.right is not None:
596                 yield from self._iterate_by_depth(node.right, depth - 1)
597
598     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
599         """
600         Args:
601             depth: the desired depth
602
603         Returns:
604             A Generator that yields nodes at the prescribed depth in
605             the tree.
606
607         >>> t = BinarySearchTree()
608         >>> t.insert(50)
609         >>> t.insert(75)
610         >>> t.insert(25)
611         >>> t.insert(66)
612         >>> t.insert(22)
613         >>> t.insert(13)
614
615         >>> for value in t.iterate_nodes_by_depth(2):
616         ...     print(value)
617         22
618         66
619
620         >>> for value in t.iterate_nodes_by_depth(3):
621         ...     print(value)
622         13
623
624         """
625         if self.root is not None:
626             yield from self._iterate_by_depth(self.root, depth)
627
628     def get_next_node(self, node: Node) -> Optional[Node]:
629         """
630         Args:
631             node: the node whose next greater successor is desired
632
633         Returns:
634             Given a tree node, returns the next greater node in the tree.
635             If the given node is the greatest node in the tree, returns None.
636
637         >>> t = BinarySearchTree()
638         >>> t.insert(50)
639         >>> t.insert(75)
640         >>> t.insert(25)
641         >>> t.insert(66)
642         >>> t.insert(22)
643         >>> t.insert(13)
644         >>> t.insert(23)
645         >>> t
646         50
647         ├──25
648         │  └──22
649         │     ├──13
650         │     └──23
651         └──75
652            └──66
653
654         >>> n = t[23]
655         >>> t.get_next_node(n).value
656         25
657
658         >>> n = t[50]
659         >>> t.get_next_node(n).value
660         66
661
662         >>> n = t[75]
663         >>> t.get_next_node(n) is None
664         True
665
666         """
667         if node.right is not None:
668             x = node.right
669             while x.left is not None:
670                 x = x.left
671             return x
672
673         path = self.parent_path(node)
674         assert path[-1] is not None
675         assert path[-1] == node
676         path = path[:-1]
677         path.reverse()
678         for ancestor in path:
679             assert ancestor is not None
680             if node != ancestor.right:
681                 return ancestor
682             node = ancestor
683         return None
684
685     def get_nodes_in_range_inclusive(
686         self, lower: ComparableNodeValue, upper: ComparableNodeValue
687     ):
688         """
689         >>> t = BinarySearchTree()
690         >>> t.insert(50)
691         >>> t.insert(75)
692         >>> t.insert(25)
693         >>> t.insert(66)
694         >>> t.insert(22)
695         >>> t.insert(13)
696         >>> t.insert(23)
697         >>> t
698         50
699         ├──25
700         │  └──22
701         │     ├──13
702         │     └──23
703         └──75
704            └──66
705
706         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
707         ...     print(node.value)
708         22
709         23
710         25
711         50
712         66
713         """
714         node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
715             lower, self.root
716         )
717         while node:
718             if lower <= node.value <= upper:
719                 yield node
720             node = self.get_next_node(node)
721
722     def _depth(self, node: Node, sofar: int) -> int:
723         depth_left = sofar + 1
724         depth_right = sofar + 1
725         if node.left is not None:
726             depth_left = self._depth(node.left, sofar + 1)
727         if node.right is not None:
728             depth_right = self._depth(node.right, sofar + 1)
729         return max(depth_left, depth_right)
730
731     def depth(self) -> int:
732         """
733         Returns:
734             The max height (depth) of the tree in plies (edge distance
735             from root).
736
737         >>> t = BinarySearchTree()
738         >>> t.depth()
739         0
740
741         >>> t.insert(50)
742         >>> t.depth()
743         1
744
745         >>> t.insert(65)
746         >>> t.depth()
747         2
748
749         >>> t.insert(33)
750         >>> t.depth()
751         2
752
753         >>> t.insert(2)
754         >>> t.insert(1)
755         >>> t.depth()
756         4
757
758         """
759         if self.root is None:
760             return 0
761         return self._depth(self.root, 0)
762
763     def height(self) -> int:
764         """Returns the height (i.e. max depth) of the tree"""
765         return self.depth()
766
767     def repr_traverse(
768         self,
769         padding: str,
770         pointer: str,
771         node: Optional[Node],
772         has_right_sibling: bool,
773     ) -> str:
774         if node is not None:
775             viz = f"\n{padding}{pointer}{node.value}"
776             if has_right_sibling:
777                 padding += "│  "
778             else:
779                 padding += "   "
780
781             pointer_right = "└──"
782             if node.right is not None:
783                 pointer_left = "├──"
784             else:
785                 pointer_left = "└──"
786
787             viz += self.repr_traverse(
788                 padding, pointer_left, node.left, node.right is not None
789             )
790             viz += self.repr_traverse(padding, pointer_right, node.right, False)
791             return viz
792         return ""
793
794     def __repr__(self):
795         """
796         Returns:
797             An ASCII string representation of the tree.
798
799         >>> t = BinarySearchTree()
800         >>> t.insert(50)
801         >>> t.insert(25)
802         >>> t.insert(75)
803         >>> t.insert(12)
804         >>> t.insert(33)
805         >>> t.insert(88)
806         >>> t.insert(55)
807         >>> t
808         50
809         ├──25
810         │  ├──12
811         │  └──33
812         └──75
813            ├──55
814            └──88
815         """
816         if self.root is None:
817             return ""
818
819         ret = f"{self.root.value}"
820         pointer_right = "└──"
821         if self.root.right is None:
822             pointer_left = "└──"
823         else:
824             pointer_left = "├──"
825
826         ret += self.repr_traverse(
827             "", pointer_left, self.root.left, self.root.left is not None
828         )
829         ret += self.repr_traverse("", pointer_right, self.root.right, False)
830         return ret
831
832
833 if __name__ == "__main__":
834     import doctest
835
836     doctest.testmod()