Dots instead of pass...
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import abstractmethod
8 from typing import Any, Generator, List, Optional, Protocol
9
10
11 class Comparable(Protocol):
12     @abstractmethod
13     def __lt__(self, other: Any) -> bool:
14         ...
15
16     @abstractmethod
17     def __le__(self, other: Any) -> bool:
18         ...
19
20     @abstractmethod
21     def __eq__(self, other: Any) -> bool:
22         ...
23
24
25 class Node:
26     def __init__(self, value: Comparable) -> None:
27         """A BST node.  Note that value can be anything as long as it
28         is comparable with other instances of itself.  Check out
29         :meth:`functools.total_ordering`
30         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
31
32         Args:
33             value: a reference to the value of the node.
34
35         """
36         self.left: Optional[Node] = None
37         self.right: Optional[Node] = None
38         self.value: Comparable = value
39
40
41 class BinarySearchTree(object):
42     def __init__(self):
43         self.root = None
44         self.count = 0
45         self.traverse = None
46
47     def get_root(self) -> Optional[Node]:
48         """
49         Returns:
50             The root of the BST
51         """
52
53         return self.root
54
55     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
56         """This is called immediately _after_ a new node is inserted."""
57         pass
58
59     def insert(self, value: Comparable) -> None:
60         """
61         Insert something into the tree.
62
63         Args:
64             value: the value to be inserted.
65
66         >>> t = BinarySearchTree()
67         >>> t.insert(10)
68         >>> t.insert(20)
69         >>> t.insert(5)
70         >>> len(t)
71         3
72
73         >>> t.get_root().value
74         10
75
76         """
77         if self.root is None:
78             self.root = Node(value)
79             self.count = 1
80             self._on_insert(None, self.root)
81         else:
82             self._insert(value, self.root)
83
84     def _insert(self, value: Comparable, node: Node):
85         """Insertion helper"""
86         if value < node.value:
87             if node.left is not None:
88                 self._insert(value, node.left)
89             else:
90                 node.left = Node(value)
91                 self.count += 1
92                 self._on_insert(node, node.left)
93         else:
94             if node.right is not None:
95                 self._insert(value, node.right)
96             else:
97                 node.right = Node(value)
98                 self.count += 1
99                 self._on_insert(node, node.right)
100
101     def __getitem__(self, value: Comparable) -> Optional[Node]:
102         """
103         Find an item in the tree and return its Node.  Returns
104         None if the item is not in the tree.
105
106         >>> t = BinarySearchTree()
107         >>> t[99]
108
109         >>> t.insert(10)
110         >>> t.insert(20)
111         >>> t.insert(5)
112         >>> t[10].value
113         10
114
115         >>> t[99]
116
117         """
118         if self.root is not None:
119             return self._find_exact(value, self.root)
120         return None
121
122     def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
123         """Recursively traverse the tree looking for a node with the
124         target value.  Return that node if it exists, otherwise return
125         None."""
126
127         if target == node.value:
128             return node
129         elif target < node.value and node.left is not None:
130             return self._find_exact(target, node.left)
131         elif target > node.value and node.right is not None:
132             return self._find_exact(target, node.right)
133         return None
134
135     def _find_lowest_node_less_than_or_equal_to(
136         self, target: Comparable, node: Optional[Node]
137     ) -> Optional[Node]:
138         """Find helper that returns the lowest node that is less
139         than or equal to the target value.  Returns None if target is
140         lower than the lowest node in the tree.
141
142         >>> t = BinarySearchTree()
143         >>> t.insert(50)
144         >>> t.insert(75)
145         >>> t.insert(25)
146         >>> t.insert(66)
147         >>> t.insert(22)
148         >>> t.insert(13)
149         >>> t.insert(85)
150         >>> t
151         50
152         ├──25
153         │  └──22
154         │     └──13
155         └──75
156            ├──66
157            └──85
158
159         >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
160         25
161         >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
162         50
163         >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
164         85
165         >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
166         22
167         >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
168         13
169         >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
170         66
171         >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
172         75
173         >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
174         True
175
176         """
177
178         if not node:
179             return None
180
181         if target == node.value:
182             return node
183
184         elif target > node.value:
185             if below := self._find_lowest_node_less_than_or_equal_to(
186                 target, node.right
187             ):
188                 return below
189             else:
190                 return node
191
192         else:
193             return self._find_lowest_node_less_than_or_equal_to(target, node.left)
194
195     def _find_lowest_node_greater_than_or_equal_to(
196         self, target: Comparable, node: Optional[Node]
197     ) -> Optional[Node]:
198         """Find helper that returns the lowest node that is greater
199         than or equal to the target value.  Returns None if target is
200         higher than the greatest node in the tree.
201
202         >>> t = BinarySearchTree()
203         >>> t.insert(50)
204         >>> t.insert(75)
205         >>> t.insert(25)
206         >>> t.insert(66)
207         >>> t.insert(22)
208         >>> t.insert(13)
209         >>> t.insert(85)
210         >>> t
211         50
212         ├──25
213         │  └──22
214         │     └──13
215         └──75
216            ├──66
217            └──85
218
219         >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
220         50
221         >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
222         66
223         >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
224         13
225         >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
226         25
227         >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
228         22
229         >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
230         75
231         >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
232         85
233         >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
234         True
235
236         """
237
238         if not node:
239             return None
240
241         if target == node.value:
242             return node
243
244         elif target > node.value:
245             return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
246
247         # If target < this node's value, either this node is the
248         # answer or the answer is in this node's left subtree.
249         else:
250             if below := self._find_lowest_node_greater_than_or_equal_to(
251                 target, node.left
252             ):
253                 return below
254             else:
255                 return node
256
257     def _parent_path(
258         self, current: Optional[Node], target: Node
259     ) -> List[Optional[Node]]:
260         """Internal helper"""
261         if current is None:
262             return [None]
263         ret: List[Optional[Node]] = [current]
264         if target.value == current.value:
265             return ret
266         elif target.value < current.value:
267             ret.extend(self._parent_path(current.left, target))
268             return ret
269         else:
270             assert target.value > current.value
271             ret.extend(self._parent_path(current.right, target))
272             return ret
273
274     def parent_path(self, node: Node) -> List[Optional[Node]]:
275         """Get a node's parent path.
276
277         Args:
278             node: the node to check
279
280         Returns:
281             a list of nodes representing the path from
282             the tree's root to the node.
283
284         .. note::
285
286             If the node does not exist in the tree, the last element
287             on the path will be None but the path will indicate the
288             ancestor path of that node were it to be inserted.
289
290         >>> t = BinarySearchTree()
291         >>> t.insert(50)
292         >>> t.insert(75)
293         >>> t.insert(25)
294         >>> t.insert(12)
295         >>> t.insert(33)
296         >>> t.insert(4)
297         >>> t.insert(88)
298         >>> t
299         50
300         ├──25
301         │  ├──12
302         │  │  └──4
303         │  └──33
304         └──75
305            └──88
306
307         >>> n = t[4]
308         >>> for x in t.parent_path(n):
309         ...     print(x.value)
310         50
311         25
312         12
313         4
314
315         >>> del t[4]
316         >>> for x in t.parent_path(n):
317         ...     if x is not None:
318         ...         print(x.value)
319         ...     else:
320         ...         print(x)
321         50
322         25
323         12
324         None
325
326         """
327         return self._parent_path(self.root, node)
328
329     def __delitem__(self, value: Comparable) -> bool:
330         """
331         Delete an item from the tree and preserve the BST property.
332
333         Args:
334             value: the value of the node to be deleted.
335
336         Returns:
337             True if the value was found and its associated node was
338             successfully deleted and False otherwise.
339
340         >>> t = BinarySearchTree()
341         >>> t.insert(50)
342         >>> t.insert(75)
343         >>> t.insert(25)
344         >>> t.insert(66)
345         >>> t.insert(22)
346         >>> t.insert(13)
347         >>> t.insert(85)
348         >>> t
349         50
350         ├──25
351         │  └──22
352         │     └──13
353         └──75
354            ├──66
355            └──85
356
357         >>> for value in t.iterate_inorder():
358         ...     print(value)
359         13
360         22
361         25
362         50
363         66
364         75
365         85
366
367         >>> del t[22]  # Note: bool result is discarded
368
369         >>> for value in t.iterate_inorder():
370         ...     print(value)
371         13
372         25
373         50
374         66
375         75
376         85
377
378         >>> t.__delitem__(13)
379         True
380         >>> for value in t.iterate_inorder():
381         ...     print(value)
382         25
383         50
384         66
385         75
386         85
387
388         >>> t.__delitem__(75)
389         True
390         >>> for value in t.iterate_inorder():
391         ...     print(value)
392         25
393         50
394         66
395         85
396         >>> t
397         50
398         ├──25
399         └──85
400            └──66
401
402         >>> t.__delitem__(99)
403         False
404
405         """
406         if self.root is not None:
407             ret = self._delete(value, None, self.root)
408             if ret:
409                 self.count -= 1
410                 if self.count == 0:
411                     self.root = None
412             return ret
413         return False
414
415     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
416         """This is called just after deleted was deleted from the tree"""
417         pass
418
419     def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
420         """Delete helper"""
421         if node.value == value:
422
423             # Deleting a leaf node
424             if node.left is None and node.right is None:
425                 if parent is not None:
426                     if parent.left == node:
427                         parent.left = None
428                     else:
429                         assert parent.right == node
430                         parent.right = None
431                 self._on_delete(parent, node)
432                 return True
433
434             # Node only has a right.
435             elif node.left is None:
436                 assert node.right is not None
437                 if parent is not None:
438                     if parent.left == node:
439                         parent.left = node.right
440                     else:
441                         assert parent.right == node
442                         parent.right = node.right
443                 self._on_delete(parent, node)
444                 return True
445
446             # Node only has a left.
447             elif node.right is None:
448                 assert node.left is not None
449                 if parent is not None:
450                     if parent.left == node:
451                         parent.left = node.left
452                     else:
453                         assert parent.right == node
454                         parent.right = node.left
455                 self._on_delete(parent, node)
456                 return True
457
458             # Node has both a left and right.
459             else:
460                 assert node.left is not None and node.right is not None
461                 descendent = node.right
462                 while descendent.left is not None:
463                     descendent = descendent.left
464                 node.value = descendent.value
465                 return self._delete(node.value, node, node.right)
466         elif value < node.value and node.left is not None:
467             return self._delete(value, node, node.left)
468         elif value > node.value and node.right is not None:
469             return self._delete(value, node, node.right)
470         return False
471
472     def __len__(self):
473         """
474         Returns:
475             The count of items in the tree.
476
477         >>> t = BinarySearchTree()
478         >>> len(t)
479         0
480         >>> t.insert(50)
481         >>> len(t)
482         1
483         >>> t.__delitem__(50)
484         True
485         >>> len(t)
486         0
487         >>> t.insert(75)
488         >>> t.insert(25)
489         >>> t.insert(66)
490         >>> t.insert(22)
491         >>> t.insert(13)
492         >>> t.insert(85)
493         >>> len(t)
494         6
495
496         """
497         return self.count
498
499     def __contains__(self, value: Comparable) -> bool:
500         """
501         Returns:
502             True if the item is in the tree; False otherwise.
503         """
504         return self.__getitem__(value) is not None
505
506     def _iterate_preorder(self, node: Node):
507         yield node.value
508         if node.left is not None:
509             yield from self._iterate_preorder(node.left)
510         if node.right is not None:
511             yield from self._iterate_preorder(node.right)
512
513     def _iterate_inorder(self, node: Node):
514         if node.left is not None:
515             yield from self._iterate_inorder(node.left)
516         yield node.value
517         if node.right is not None:
518             yield from self._iterate_inorder(node.right)
519
520     def _iterate_postorder(self, node: Node):
521         if node.left is not None:
522             yield from self._iterate_postorder(node.left)
523         if node.right is not None:
524             yield from self._iterate_postorder(node.right)
525         yield node.value
526
527     def iterate_preorder(self):
528         """
529         Returns:
530             A Generator that yields the tree's items in a
531             preorder traversal sequence.
532
533         >>> t = BinarySearchTree()
534         >>> t.insert(50)
535         >>> t.insert(75)
536         >>> t.insert(25)
537         >>> t.insert(66)
538         >>> t.insert(22)
539         >>> t.insert(13)
540
541         >>> for value in t.iterate_preorder():
542         ...     print(value)
543         50
544         25
545         22
546         13
547         75
548         66
549
550         """
551         if self.root is not None:
552             yield from self._iterate_preorder(self.root)
553
554     def iterate_inorder(self):
555         """
556         Returns:
557             A Generator that yield the tree's items in a preorder
558             traversal sequence.
559
560         >>> t = BinarySearchTree()
561         >>> t.insert(50)
562         >>> t.insert(75)
563         >>> t.insert(25)
564         >>> t.insert(66)
565         >>> t.insert(22)
566         >>> t.insert(13)
567         >>> t.insert(24)
568         >>> t
569         50
570         ├──25
571         │  └──22
572         │     ├──13
573         │     └──24
574         └──75
575            └──66
576
577         >>> for value in t.iterate_inorder():
578         ...     print(value)
579         13
580         22
581         24
582         25
583         50
584         66
585         75
586
587         """
588         if self.root is not None:
589             yield from self._iterate_inorder(self.root)
590
591     def iterate_postorder(self):
592         """
593         Returns:
594             A Generator that yield the tree's items in a preorder
595             traversal sequence.
596
597         >>> t = BinarySearchTree()
598         >>> t.insert(50)
599         >>> t.insert(75)
600         >>> t.insert(25)
601         >>> t.insert(66)
602         >>> t.insert(22)
603         >>> t.insert(13)
604
605         >>> for value in t.iterate_postorder():
606         ...     print(value)
607         13
608         22
609         25
610         66
611         75
612         50
613
614         """
615         if self.root is not None:
616             yield from self._iterate_postorder(self.root)
617
618     def _iterate_leaves(self, node: Node):
619         if node.left is not None:
620             yield from self._iterate_leaves(node.left)
621         if node.right is not None:
622             yield from self._iterate_leaves(node.right)
623         if node.left is None and node.right is None:
624             yield node.value
625
626     def iterate_leaves(self):
627         """
628         Returns:
629             A Gemerator that yielde only the leaf nodes in the
630             tree.
631
632         >>> t = BinarySearchTree()
633         >>> t.insert(50)
634         >>> t.insert(75)
635         >>> t.insert(25)
636         >>> t.insert(66)
637         >>> t.insert(22)
638         >>> t.insert(13)
639
640         >>> for value in t.iterate_leaves():
641         ...     print(value)
642         13
643         66
644
645         """
646         if self.root is not None:
647             yield from self._iterate_leaves(self.root)
648
649     def _iterate_by_depth(self, node: Node, depth: int):
650         if depth == 0:
651             yield node.value
652         else:
653             assert depth > 0
654             if node.left is not None:
655                 yield from self._iterate_by_depth(node.left, depth - 1)
656             if node.right is not None:
657                 yield from self._iterate_by_depth(node.right, depth - 1)
658
659     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
660         """
661         Args:
662             depth: the desired depth
663
664         Returns:
665             A Generator that yields nodes at the prescribed depth in
666             the tree.
667
668         >>> t = BinarySearchTree()
669         >>> t.insert(50)
670         >>> t.insert(75)
671         >>> t.insert(25)
672         >>> t.insert(66)
673         >>> t.insert(22)
674         >>> t.insert(13)
675
676         >>> for value in t.iterate_nodes_by_depth(2):
677         ...     print(value)
678         22
679         66
680
681         >>> for value in t.iterate_nodes_by_depth(3):
682         ...     print(value)
683         13
684
685         """
686         if self.root is not None:
687             yield from self._iterate_by_depth(self.root, depth)
688
689     def get_next_node(self, node: Node) -> Optional[Node]:
690         """
691         Args:
692             node: the node whose next greater successor is desired
693
694         Returns:
695             Given a tree node, returns the next greater node in the tree.
696             If the given node is the greatest node in the tree, returns None.
697
698         >>> t = BinarySearchTree()
699         >>> t.insert(50)
700         >>> t.insert(75)
701         >>> t.insert(25)
702         >>> t.insert(66)
703         >>> t.insert(22)
704         >>> t.insert(13)
705         >>> t.insert(23)
706         >>> t
707         50
708         ├──25
709         │  └──22
710         │     ├──13
711         │     └──23
712         └──75
713            └──66
714
715         >>> n = t[23]
716         >>> t.get_next_node(n).value
717         25
718
719         >>> n = t[50]
720         >>> t.get_next_node(n).value
721         66
722
723         >>> n = t[75]
724         >>> t.get_next_node(n) is None
725         True
726
727         """
728         if node.right is not None:
729             x = node.right
730             while x.left is not None:
731                 x = x.left
732             return x
733
734         path = self.parent_path(node)
735         assert path[-1] is not None
736         assert path[-1] == node
737         path = path[:-1]
738         path.reverse()
739         for ancestor in path:
740             assert ancestor is not None
741             if node != ancestor.right:
742                 return ancestor
743             node = ancestor
744         return None
745
746     def get_nodes_in_range_inclusive(self, lower: Comparable, upper: Comparable):
747         """
748         >>> t = BinarySearchTree()
749         >>> t.insert(50)
750         >>> t.insert(75)
751         >>> t.insert(25)
752         >>> t.insert(66)
753         >>> t.insert(22)
754         >>> t.insert(13)
755         >>> t.insert(23)
756         >>> t
757         50
758         ├──25
759         │  └──22
760         │     ├──13
761         │     └──23
762         └──75
763            └──66
764
765         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
766         ...     print(node.value)
767         22
768         23
769         25
770         50
771         66
772         """
773         node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
774             lower, self.root
775         )
776         while node:
777             if lower <= node.value <= upper:
778                 yield node
779             node = self.get_next_node(node)
780
781     def _depth(self, node: Node, sofar: int) -> int:
782         depth_left = sofar + 1
783         depth_right = sofar + 1
784         if node.left is not None:
785             depth_left = self._depth(node.left, sofar + 1)
786         if node.right is not None:
787             depth_right = self._depth(node.right, sofar + 1)
788         return max(depth_left, depth_right)
789
790     def depth(self) -> int:
791         """
792         Returns:
793             The max height (depth) of the tree in plies (edge distance
794             from root).
795
796         >>> t = BinarySearchTree()
797         >>> t.depth()
798         0
799
800         >>> t.insert(50)
801         >>> t.depth()
802         1
803
804         >>> t.insert(65)
805         >>> t.depth()
806         2
807
808         >>> t.insert(33)
809         >>> t.depth()
810         2
811
812         >>> t.insert(2)
813         >>> t.insert(1)
814         >>> t.depth()
815         4
816
817         """
818         if self.root is None:
819             return 0
820         return self._depth(self.root, 0)
821
822     def height(self) -> int:
823         """Returns the height (i.e. max depth) of the tree"""
824         return self.depth()
825
826     def repr_traverse(
827         self,
828         padding: str,
829         pointer: str,
830         node: Optional[Node],
831         has_right_sibling: bool,
832     ) -> str:
833         if node is not None:
834             viz = f"\n{padding}{pointer}{node.value}"
835             if has_right_sibling:
836                 padding += "│  "
837             else:
838                 padding += "   "
839
840             pointer_right = "└──"
841             if node.right is not None:
842                 pointer_left = "├──"
843             else:
844                 pointer_left = "└──"
845
846             viz += self.repr_traverse(
847                 padding, pointer_left, node.left, node.right is not None
848             )
849             viz += self.repr_traverse(padding, pointer_right, node.right, False)
850             return viz
851         return ""
852
853     def __repr__(self):
854         """
855         Returns:
856             An ASCII string representation of the tree.
857
858         >>> t = BinarySearchTree()
859         >>> t.insert(50)
860         >>> t.insert(25)
861         >>> t.insert(75)
862         >>> t.insert(12)
863         >>> t.insert(33)
864         >>> t.insert(88)
865         >>> t.insert(55)
866         >>> t
867         50
868         ├──25
869         │  ├──12
870         │  └──33
871         └──75
872            ├──55
873            └──88
874         """
875         if self.root is None:
876             return ""
877
878         ret = f"{self.root.value}"
879         pointer_right = "└──"
880         if self.root.right is None:
881             pointer_left = "└──"
882         else:
883             pointer_left = "├──"
884
885         ret += self.repr_traverse(
886             "", pointer_left, self.root.left, self.root.left is not None
887         )
888         ret += self.repr_traverse("", pointer_right, self.root.right, False)
889         return ret
890
891
892 if __name__ == "__main__":
893     import doctest
894
895     doctest.testmod()