Adds find_lowest_node_less_than_or_equal_to.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
9
10
11 class Comparable(metaclass=ABCMeta):
12     @abstractmethod
13     def __lt__(self, other: Any) -> bool:
14         pass
15
16     @abstractmethod
17     def __le__(self, other: Any) -> bool:
18         pass
19
20     @abstractmethod
21     def __eq__(self, other: Any) -> bool:
22         pass
23
24
25 ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
26
27
28 class Node(object):
29     def __init__(self, value: ComparableNodeValue) -> None:
30         """A BST node.  Note that value can be anything as long as it
31         is comparable with other instances of itself.  Check out
32         :meth:`functools.total_ordering`
33         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
34
35         Args:
36             value: a reference to the value of the node.
37
38         """
39         self.left: Optional[Node] = None
40         self.right: Optional[Node] = None
41         self.value: ComparableNodeValue = value
42
43
44 class BinarySearchTree(object):
45     def __init__(self):
46         self.root = None
47         self.count = 0
48         self.traverse = None
49
50     def get_root(self) -> Optional[Node]:
51         """
52         Returns:
53             The root of the BST
54         """
55
56         return self.root
57
58     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
59         """This is called immediately _after_ a new node is inserted."""
60         pass
61
62     def insert(self, value: ComparableNodeValue) -> None:
63         """
64         Insert something into the tree.
65
66         Args:
67             value: the value to be inserted.
68
69         >>> t = BinarySearchTree()
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> len(t)
74         3
75
76         >>> t.get_root().value
77         10
78
79         """
80         if self.root is None:
81             self.root = Node(value)
82             self.count = 1
83             self._on_insert(None, self.root)
84         else:
85             self._insert(value, self.root)
86
87     def _insert(self, value: ComparableNodeValue, node: Node):
88         """Insertion helper"""
89         if value < node.value:
90             if node.left is not None:
91                 self._insert(value, node.left)
92             else:
93                 node.left = Node(value)
94                 self.count += 1
95                 self._on_insert(node, node.left)
96         else:
97             if node.right is not None:
98                 self._insert(value, node.right)
99             else:
100                 node.right = Node(value)
101                 self.count += 1
102                 self._on_insert(node, node.right)
103
104     def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
105         """
106         Find an item in the tree and return its Node.  Returns
107         None if the item is not in the tree.
108
109         >>> t = BinarySearchTree()
110         >>> t[99]
111
112         >>> t.insert(10)
113         >>> t.insert(20)
114         >>> t.insert(5)
115         >>> t[10].value
116         10
117
118         >>> t[99]
119
120         """
121         if self.root is not None:
122             return self._find_exact(value, self.root)
123         return None
124
125     def _find_exact(self, target: ComparableNodeValue, node: Node) -> Optional[Node]:
126         """Recursively traverse the tree looking for a node with the
127         target value.  Return that node if it exists, otherwise return
128         None."""
129
130         if target == node.value:
131             return node
132         elif target < node.value and node.left is not None:
133             return self._find_exact(target, node.left)
134         elif target > node.value and node.right is not None:
135             return self._find_exact(target, node.right)
136         return None
137
138     def _find_lowest_node_less_than_or_equal_to(
139         self, target: ComparableNodeValue, node: Optional[Node]
140     ) -> Optional[Node]:
141         """Find helper that returns the lowest node that is less
142         than or equal to the target value.  Returns None if target is
143         lower than the lowest node in the tree.
144
145         >>> t = BinarySearchTree()
146         >>> t.insert(50)
147         >>> t.insert(75)
148         >>> t.insert(25)
149         >>> t.insert(66)
150         >>> t.insert(22)
151         >>> t.insert(13)
152         >>> t.insert(85)
153         >>> t
154         50
155         ├──25
156         │  └──22
157         │     └──13
158         └──75
159            ├──66
160            └──85
161
162         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
163         25
164         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
165         50
166         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
167         85
168         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
169         22
170         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
171         13
172         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
173         66
174         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
175         75
176         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
177         True
178
179         """
180
181         if not node:
182             return None
183
184         if target == node.value:
185             return node
186
187         elif target > node.value:
188             if below := self._find_lowest_node_less_than_or_equal_to(
189                 target, node.right
190             ):
191                 return below
192             else:
193                 return node
194
195         else:
196             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
197
198     def _find_lowest_node_greater_than_or_equal_to(
199         self, target: ComparableNodeValue, node: Optional[Node]
200     ) -> Optional[Node]:
201         """Find helper that returns the lowest node that is greater
202         than or equal to the target value.  Returns None if target is
203         higher than the greatest node in the tree.
204
205         >>> t = BinarySearchTree()
206         >>> t.insert(50)
207         >>> t.insert(75)
208         >>> t.insert(25)
209         >>> t.insert(66)
210         >>> t.insert(22)
211         >>> t.insert(13)
212         >>> t.insert(85)
213         >>> t
214         50
215         ├──25
216         │  └──22
217         │     └──13
218         └──75
219            ├──66
220            └──85
221
222         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
223         50
224         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
225         66
226         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
227         13
228         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
229         25
230         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
231         22
232         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
233         75
234         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
235         85
236         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
237         True
238
239         """
240
241         if not node:
242             return None
243
244         if target == node.value:
245             return node
246
247         elif target > node.value:
248             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
249
250         # If target < this node's value, either this node is the
251         # answer or the answer is in this node's left subtree.
252         else:
253             if below := self._find_lowest_node_greater_than_or_equal_to(
254                 target, node.left
255             ):
256                 return below
257             else:
258                 return node
259
260     def _parent_path(
261         self, current: Optional[Node], target: Node
262     ) -> List[Optional[Node]]:
263         """Internal helper"""
264         if current is None:
265             return [None]
266         ret: List[Optional[Node]] = [current]
267         if target.value == current.value:
268             return ret
269         elif target.value < current.value:
270             ret.extend(self._parent_path(current.left, target))
271             return ret
272         else:
273             assert target.value > current.value
274             ret.extend(self._parent_path(current.right, target))
275             return ret
276
277     def parent_path(self, node: Node) -> List[Optional[Node]]:
278         """Get a node's parent path.
279
280         Args:
281             node: the node to check
282
283         Returns:
284             a list of nodes representing the path from
285             the tree's root to the node.
286
287         .. note::
288
289             If the node does not exist in the tree, the last element
290             on the path will be None but the path will indicate the
291             ancestor path of that node were it to be inserted.
292
293         >>> t = BinarySearchTree()
294         >>> t.insert(50)
295         >>> t.insert(75)
296         >>> t.insert(25)
297         >>> t.insert(12)
298         >>> t.insert(33)
299         >>> t.insert(4)
300         >>> t.insert(88)
301         >>> t
302         50
303         ├──25
304         │  ├──12
305         │  │  └──4
306         │  └──33
307         └──75
308            └──88
309
310         >>> n = t[4]
311         >>> for x in t.parent_path(n):
312         ...     print(x.value)
313         50
314         25
315         12
316         4
317
318         >>> del t[4]
319         >>> for x in t.parent_path(n):
320         ...     if x is not None:
321         ...         print(x.value)
322         ...     else:
323         ...         print(x)
324         50
325         25
326         12
327         None
328
329         """
330         return self._parent_path(self.root, node)
331
332     def __delitem__(self, value: ComparableNodeValue) -> bool:
333         """
334         Delete an item from the tree and preserve the BST property.
335
336         Args:
337             value: the value of the node to be deleted.
338
339         Returns:
340             True if the value was found and its associated node was
341             successfully deleted and False otherwise.
342
343         >>> t = BinarySearchTree()
344         >>> t.insert(50)
345         >>> t.insert(75)
346         >>> t.insert(25)
347         >>> t.insert(66)
348         >>> t.insert(22)
349         >>> t.insert(13)
350         >>> t.insert(85)
351         >>> t
352         50
353         ├──25
354         │  └──22
355         │     └──13
356         └──75
357            ├──66
358            └──85
359
360         >>> for value in t.iterate_inorder():
361         ...     print(value)
362         13
363         22
364         25
365         50
366         66
367         75
368         85
369
370         >>> del t[22]  # Note: bool result is discarded
371
372         >>> for value in t.iterate_inorder():
373         ...     print(value)
374         13
375         25
376         50
377         66
378         75
379         85
380
381         >>> t.__delitem__(13)
382         True
383         >>> for value in t.iterate_inorder():
384         ...     print(value)
385         25
386         50
387         66
388         75
389         85
390
391         >>> t.__delitem__(75)
392         True
393         >>> for value in t.iterate_inorder():
394         ...     print(value)
395         25
396         50
397         66
398         85
399         >>> t
400         50
401         ├──25
402         └──85
403            └──66
404
405         >>> t.__delitem__(99)
406         False
407
408         """
409         if self.root is not None:
410             ret = self._delete(value, None, self.root)
411             if ret:
412                 self.count -= 1
413                 if self.count == 0:
414                     self.root = None
415             return ret
416         return False
417
418     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
419         """This is called just after deleted was deleted from the tree"""
420         pass
421
422     def _delete(
423         self, value: ComparableNodeValue, parent: Optional[Node], node: Node
424     ) -> bool:
425         """Delete helper"""
426         if node.value == value:
427
428             # Deleting a leaf node
429             if node.left is None and node.right is None:
430                 if parent is not None:
431                     if parent.left == node:
432                         parent.left = None
433                     else:
434                         assert parent.right == node
435                         parent.right = None
436                 self._on_delete(parent, node)
437                 return True
438
439             # Node only has a right.
440             elif node.left is None:
441                 assert node.right is not None
442                 if parent is not None:
443                     if parent.left == node:
444                         parent.left = node.right
445                     else:
446                         assert parent.right == node
447                         parent.right = node.right
448                 self._on_delete(parent, node)
449                 return True
450
451             # Node only has a left.
452             elif node.right is None:
453                 assert node.left is not None
454                 if parent is not None:
455                     if parent.left == node:
456                         parent.left = node.left
457                     else:
458                         assert parent.right == node
459                         parent.right = node.left
460                 self._on_delete(parent, node)
461                 return True
462
463             # Node has both a left and right.
464             else:
465                 assert node.left is not None and node.right is not None
466                 descendent = node.right
467                 while descendent.left is not None:
468                     descendent = descendent.left
469                 node.value = descendent.value
470                 return self._delete(node.value, node, node.right)
471         elif value < node.value and node.left is not None:
472             return self._delete(value, node, node.left)
473         elif value > node.value and node.right is not None:
474             return self._delete(value, node, node.right)
475         return False
476
477     def __len__(self):
478         """
479         Returns:
480             The count of items in the tree.
481
482         >>> t = BinarySearchTree()
483         >>> len(t)
484         0
485         >>> t.insert(50)
486         >>> len(t)
487         1
488         >>> t.__delitem__(50)
489         True
490         >>> len(t)
491         0
492         >>> t.insert(75)
493         >>> t.insert(25)
494         >>> t.insert(66)
495         >>> t.insert(22)
496         >>> t.insert(13)
497         >>> t.insert(85)
498         >>> len(t)
499         6
500
501         """
502         return self.count
503
504     def __contains__(self, value: ComparableNodeValue) -> bool:
505         """
506         Returns:
507             True if the item is in the tree; False otherwise.
508         """
509         return self.__getitem__(value) is not None
510
511     def _iterate_preorder(self, node: Node):
512         yield node.value
513         if node.left is not None:
514             yield from self._iterate_preorder(node.left)
515         if node.right is not None:
516             yield from self._iterate_preorder(node.right)
517
518     def _iterate_inorder(self, node: Node):
519         if node.left is not None:
520             yield from self._iterate_inorder(node.left)
521         yield node.value
522         if node.right is not None:
523             yield from self._iterate_inorder(node.right)
524
525     def _iterate_postorder(self, node: Node):
526         if node.left is not None:
527             yield from self._iterate_postorder(node.left)
528         if node.right is not None:
529             yield from self._iterate_postorder(node.right)
530         yield node.value
531
532     def iterate_preorder(self):
533         """
534         Returns:
535             A Generator that yields the tree's items in a
536             preorder traversal sequence.
537
538         >>> t = BinarySearchTree()
539         >>> t.insert(50)
540         >>> t.insert(75)
541         >>> t.insert(25)
542         >>> t.insert(66)
543         >>> t.insert(22)
544         >>> t.insert(13)
545
546         >>> for value in t.iterate_preorder():
547         ...     print(value)
548         50
549         25
550         22
551         13
552         75
553         66
554
555         """
556         if self.root is not None:
557             yield from self._iterate_preorder(self.root)
558
559     def iterate_inorder(self):
560         """
561         Returns:
562             A Generator that yield the tree's items in a preorder
563             traversal sequence.
564
565         >>> t = BinarySearchTree()
566         >>> t.insert(50)
567         >>> t.insert(75)
568         >>> t.insert(25)
569         >>> t.insert(66)
570         >>> t.insert(22)
571         >>> t.insert(13)
572         >>> t.insert(24)
573         >>> t
574         50
575         ├──25
576         │  └──22
577         │     ├──13
578         │     └──24
579         └──75
580            └──66
581
582         >>> for value in t.iterate_inorder():
583         ...     print(value)
584         13
585         22
586         24
587         25
588         50
589         66
590         75
591
592         """
593         if self.root is not None:
594             yield from self._iterate_inorder(self.root)
595
596     def iterate_postorder(self):
597         """
598         Returns:
599             A Generator that yield the tree's items in a preorder
600             traversal sequence.
601
602         >>> t = BinarySearchTree()
603         >>> t.insert(50)
604         >>> t.insert(75)
605         >>> t.insert(25)
606         >>> t.insert(66)
607         >>> t.insert(22)
608         >>> t.insert(13)
609
610         >>> for value in t.iterate_postorder():
611         ...     print(value)
612         13
613         22
614         25
615         66
616         75
617         50
618
619         """
620         if self.root is not None:
621             yield from self._iterate_postorder(self.root)
622
623     def _iterate_leaves(self, node: Node):
624         if node.left is not None:
625             yield from self._iterate_leaves(node.left)
626         if node.right is not None:
627             yield from self._iterate_leaves(node.right)
628         if node.left is None and node.right is None:
629             yield node.value
630
631     def iterate_leaves(self):
632         """
633         Returns:
634             A Gemerator that yielde only the leaf nodes in the
635             tree.
636
637         >>> t = BinarySearchTree()
638         >>> t.insert(50)
639         >>> t.insert(75)
640         >>> t.insert(25)
641         >>> t.insert(66)
642         >>> t.insert(22)
643         >>> t.insert(13)
644
645         >>> for value in t.iterate_leaves():
646         ...     print(value)
647         13
648         66
649
650         """
651         if self.root is not None:
652             yield from self._iterate_leaves(self.root)
653
654     def _iterate_by_depth(self, node: Node, depth: int):
655         if depth == 0:
656             yield node.value
657         else:
658             assert depth > 0
659             if node.left is not None:
660                 yield from self._iterate_by_depth(node.left, depth - 1)
661             if node.right is not None:
662                 yield from self._iterate_by_depth(node.right, depth - 1)
663
664     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
665         """
666         Args:
667             depth: the desired depth
668
669         Returns:
670             A Generator that yields nodes at the prescribed depth in
671             the tree.
672
673         >>> t = BinarySearchTree()
674         >>> t.insert(50)
675         >>> t.insert(75)
676         >>> t.insert(25)
677         >>> t.insert(66)
678         >>> t.insert(22)
679         >>> t.insert(13)
680
681         >>> for value in t.iterate_nodes_by_depth(2):
682         ...     print(value)
683         22
684         66
685
686         >>> for value in t.iterate_nodes_by_depth(3):
687         ...     print(value)
688         13
689
690         """
691         if self.root is not None:
692             yield from self._iterate_by_depth(self.root, depth)
693
694     def get_next_node(self, node: Node) -> Optional[Node]:
695         """
696         Args:
697             node: the node whose next greater successor is desired
698
699         Returns:
700             Given a tree node, returns the next greater node in the tree.
701             If the given node is the greatest node in the tree, returns None.
702
703         >>> t = BinarySearchTree()
704         >>> t.insert(50)
705         >>> t.insert(75)
706         >>> t.insert(25)
707         >>> t.insert(66)
708         >>> t.insert(22)
709         >>> t.insert(13)
710         >>> t.insert(23)
711         >>> t
712         50
713         ├──25
714         │  └──22
715         │     ├──13
716         │     └──23
717         └──75
718            └──66
719
720         >>> n = t[23]
721         >>> t.get_next_node(n).value
722         25
723
724         >>> n = t[50]
725         >>> t.get_next_node(n).value
726         66
727
728         >>> n = t[75]
729         >>> t.get_next_node(n) is None
730         True
731
732         """
733         if node.right is not None:
734             x = node.right
735             while x.left is not None:
736                 x = x.left
737             return x
738
739         path = self.parent_path(node)
740         assert path[-1] is not None
741         assert path[-1] == node
742         path = path[:-1]
743         path.reverse()
744         for ancestor in path:
745             assert ancestor is not None
746             if node != ancestor.right:
747                 return ancestor
748             node = ancestor
749         return None
750
751     def get_nodes_in_range_inclusive(
752         self, lower: ComparableNodeValue, upper: ComparableNodeValue
753     ):
754         """
755         >>> t = BinarySearchTree()
756         >>> t.insert(50)
757         >>> t.insert(75)
758         >>> t.insert(25)
759         >>> t.insert(66)
760         >>> t.insert(22)
761         >>> t.insert(13)
762         >>> t.insert(23)
763         >>> t
764         50
765         ├──25
766         │  └──22
767         │     ├──13
768         │     └──23
769         └──75
770            └──66
771
772         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
773         ...     print(node.value)
774         22
775         23
776         25
777         50
778         66
779         """
780         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
781             lower, self.root
782         )
783         while node:
784             if lower <= node.value <= upper:
785                 yield node
786             node = self.get_next_node(node)
787
788     def _depth(self, node: Node, sofar: int) -> int:
789         depth_left = sofar + 1
790         depth_right = sofar + 1
791         if node.left is not None:
792             depth_left = self._depth(node.left, sofar + 1)
793         if node.right is not None:
794             depth_right = self._depth(node.right, sofar + 1)
795         return max(depth_left, depth_right)
796
797     def depth(self) -> int:
798         """
799         Returns:
800             The max height (depth) of the tree in plies (edge distance
801             from root).
802
803         >>> t = BinarySearchTree()
804         >>> t.depth()
805         0
806
807         >>> t.insert(50)
808         >>> t.depth()
809         1
810
811         >>> t.insert(65)
812         >>> t.depth()
813         2
814
815         >>> t.insert(33)
816         >>> t.depth()
817         2
818
819         >>> t.insert(2)
820         >>> t.insert(1)
821         >>> t.depth()
822         4
823
824         """
825         if self.root is None:
826             return 0
827         return self._depth(self.root, 0)
828
829     def height(self) -> int:
830         """Returns the height (i.e. max depth) of the tree"""
831         return self.depth()
832
833     def repr_traverse(
834         self,
835         padding: str,
836         pointer: str,
837         node: Optional[Node],
838         has_right_sibling: bool,
839     ) -> str:
840         if node is not None:
841             viz = f"\n{padding}{pointer}{node.value}"
842             if has_right_sibling:
843                 padding += "│  "
844             else:
845                 padding += "   "
846
847             pointer_right = "└──"
848             if node.right is not None:
849                 pointer_left = "├──"
850             else:
851                 pointer_left = "└──"
852
853             viz += self.repr_traverse(
854                 padding, pointer_left, node.left, node.right is not None
855             )
856             viz += self.repr_traverse(padding, pointer_right, node.right, False)
857             return viz
858         return ""
859
860     def __repr__(self):
861         """
862         Returns:
863             An ASCII string representation of the tree.
864
865         >>> t = BinarySearchTree()
866         >>> t.insert(50)
867         >>> t.insert(25)
868         >>> t.insert(75)
869         >>> t.insert(12)
870         >>> t.insert(33)
871         >>> t.insert(88)
872         >>> t.insert(55)
873         >>> t
874         50
875         ├──25
876         │  ├──12
877         │  └──33
878         └──75
879            ├──55
880            └──88
881         """
882         if self.root is None:
883             return ""
884
885         ret = f"{self.root.value}"
886         pointer_right = "└──"
887         if self.root.right is None:
888             pointer_left = "└──"
889         else:
890             pointer_left = "├──"
891
892         ret += self.repr_traverse(
893             "", pointer_left, self.root.left, self.root.left is not None
894         )
895         ret += self.repr_traverse("", pointer_right, self.root.right, False)
896         return ret
897
898
899 if __name__ == "__main__":
900     import doctest
901
902     doctest.testmod()