Cleanup code / comments.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from abc import ABCMeta, abstractmethod
8 from typing import Any, Generator, List, Optional, TypeVar
9
10
11 class Comparable(metaclass=ABCMeta):
12     @abstractmethod
13     def __lt__(self, other: Any) -> bool:
14         pass
15
16     @abstractmethod
17     def __le__(self, other: Any) -> bool:
18         pass
19
20     @abstractmethod
21     def __eq__(self, other: Any) -> bool:
22         pass
23
24
25 ComparableNodeValue = TypeVar('ComparableNodeValue', bound=Comparable)
26
27
28 class Node(object):
29     def __init__(self, value: ComparableNodeValue) -> None:
30         """A BST node.  Note that value can be anything as long as it
31         is comparable with other instances of itself.  Check out
32         :meth:`functools.total_ordering`
33         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
34
35         Args:
36             value: a reference to the value of the node.
37
38         """
39         self.left: Optional[Node] = None
40         self.right: Optional[Node] = None
41         self.value: ComparableNodeValue = value
42
43
44 class BinarySearchTree(object):
45     def __init__(self):
46         self.root = None
47         self.count = 0
48         self.traverse = None
49
50     def get_root(self) -> Optional[Node]:
51         """
52         Returns:
53             The root of the BST
54         """
55
56         return self.root
57
58     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
59         """This is called immediately _after_ a new node is inserted."""
60         pass
61
62     def insert(self, value: ComparableNodeValue) -> None:
63         """
64         Insert something into the tree.
65
66         Args:
67             value: the value to be inserted.
68
69         >>> t = BinarySearchTree()
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> len(t)
74         3
75
76         >>> t.get_root().value
77         10
78
79         """
80         if self.root is None:
81             self.root = Node(value)
82             self.count = 1
83             self._on_insert(None, self.root)
84         else:
85             self._insert(value, self.root)
86
87     def _insert(self, value: ComparableNodeValue, node: Node):
88         """Insertion helper"""
89         if value < node.value:
90             if node.left is not None:
91                 self._insert(value, node.left)
92             else:
93                 node.left = Node(value)
94                 self.count += 1
95                 self._on_insert(node, node.left)
96         else:
97             if node.right is not None:
98                 self._insert(value, node.right)
99             else:
100                 node.right = Node(value)
101                 self.count += 1
102                 self._on_insert(node, node.right)
103
104     def __getitem__(self, value: ComparableNodeValue) -> Optional[Node]:
105         """
106         Find an item in the tree and return its Node.  Returns
107         None if the item is not in the tree.
108
109         >>> t = BinarySearchTree()
110         >>> t[99]
111
112         >>> t.insert(10)
113         >>> t.insert(20)
114         >>> t.insert(5)
115         >>> t[10].value
116         10
117
118         >>> t[99]
119
120         """
121         if self.root is not None:
122             return self._find(value, self.root)
123         return None
124
125     def _find(self, value: ComparableNodeValue, node: Node) -> Optional[Node]:
126         """Find helper"""
127         if value == node.value:
128             return node
129         elif value < node.value and node.left is not None:
130             return self._find(value, node.left)
131         elif value > node.value and node.right is not None:
132             return self._find(value, node.right)
133         return None
134
135     def _find_lowest_value_greater_than_or_equal_to(
136         self, target: ComparableNodeValue, node: Optional[Node]
137     ) -> Optional[Node]:
138         """Find helper that returns the lowest node that is greater
139         than or equal to the target value.  Returns None if target is
140         greater than the highest node in the tree.
141
142         >>> t = BinarySearchTree()
143         >>> t.insert(50)
144         >>> t.insert(75)
145         >>> t.insert(25)
146         >>> t.insert(66)
147         >>> t.insert(22)
148         >>> t.insert(13)
149         >>> t.insert(85)
150         >>> t
151         50
152         ├──25
153         │  └──22
154         │     └──13
155         └──75
156            ├──66
157            └──85
158
159         >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
160         50
161         >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
162         66
163         >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
164         13
165         >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
166         25
167         >>> t._find_lowest_value_greater_than_or_equal_to(20, t.root).value
168         22
169         >>> t._find_lowest_value_greater_than_or_equal_to(72, t.root).value
170         75
171         >>> t._find_lowest_value_greater_than_or_equal_to(78, t.root).value
172         85
173         >>> t._find_lowest_value_greater_than_or_equal_to(95, t.root) is None
174         True
175
176         """
177
178         if not node:
179             return None
180
181         if target == node.value:
182             return node
183
184         elif target > node.value:
185             return self._find_lowest_value_greater_than_or_equal_to(target, node.right)
186
187         # If target < this node's value, either this node is the
188         # answer or the answer is in this node's left subtree.
189         else:
190             if below := self._find_lowest_value_greater_than_or_equal_to(
191                 target, node.left
192             ):
193                 return below
194             else:
195                 return node
196
197     def _parent_path(
198         self, current: Optional[Node], target: Node
199     ) -> List[Optional[Node]]:
200         """Internal helper"""
201         if current is None:
202             return [None]
203         ret: List[Optional[Node]] = [current]
204         if target.value == current.value:
205             return ret
206         elif target.value < current.value:
207             ret.extend(self._parent_path(current.left, target))
208             return ret
209         else:
210             assert target.value > current.value
211             ret.extend(self._parent_path(current.right, target))
212             return ret
213
214     def parent_path(self, node: Node) -> List[Optional[Node]]:
215         """Get a node's parent path.
216
217         Args:
218             node: the node to check
219
220         Returns:
221             a list of nodes representing the path from
222             the tree's root to the node.
223
224         .. note::
225
226             If the node does not exist in the tree, the last element
227             on the path will be None but the path will indicate the
228             ancestor path of that node were it to be inserted.
229
230         >>> t = BinarySearchTree()
231         >>> t.insert(50)
232         >>> t.insert(75)
233         >>> t.insert(25)
234         >>> t.insert(12)
235         >>> t.insert(33)
236         >>> t.insert(4)
237         >>> t.insert(88)
238         >>> t
239         50
240         ├──25
241         │  ├──12
242         │  │  └──4
243         │  └──33
244         └──75
245            └──88
246
247         >>> n = t[4]
248         >>> for x in t.parent_path(n):
249         ...     print(x.value)
250         50
251         25
252         12
253         4
254
255         >>> del t[4]
256         >>> for x in t.parent_path(n):
257         ...     if x is not None:
258         ...         print(x.value)
259         ...     else:
260         ...         print(x)
261         50
262         25
263         12
264         None
265
266         """
267         return self._parent_path(self.root, node)
268
269     def __delitem__(self, value: ComparableNodeValue) -> bool:
270         """
271         Delete an item from the tree and preserve the BST property.
272
273         Args:
274             value: the value of the node to be deleted.
275
276         Returns:
277             True if the value was found and its associated node was
278             successfully deleted and False otherwise.
279
280         >>> t = BinarySearchTree()
281         >>> t.insert(50)
282         >>> t.insert(75)
283         >>> t.insert(25)
284         >>> t.insert(66)
285         >>> t.insert(22)
286         >>> t.insert(13)
287         >>> t.insert(85)
288         >>> t
289         50
290         ├──25
291         │  └──22
292         │     └──13
293         └──75
294            ├──66
295            └──85
296
297         >>> for value in t.iterate_inorder():
298         ...     print(value)
299         13
300         22
301         25
302         50
303         66
304         75
305         85
306
307         >>> del t[22]  # Note: bool result is discarded
308
309         >>> for value in t.iterate_inorder():
310         ...     print(value)
311         13
312         25
313         50
314         66
315         75
316         85
317
318         >>> t.__delitem__(13)
319         True
320         >>> for value in t.iterate_inorder():
321         ...     print(value)
322         25
323         50
324         66
325         75
326         85
327
328         >>> t.__delitem__(75)
329         True
330         >>> for value in t.iterate_inorder():
331         ...     print(value)
332         25
333         50
334         66
335         85
336         >>> t
337         50
338         ├──25
339         └──85
340            └──66
341
342         >>> t.__delitem__(99)
343         False
344
345         """
346         if self.root is not None:
347             ret = self._delete(value, None, self.root)
348             if ret:
349                 self.count -= 1
350                 if self.count == 0:
351                     self.root = None
352             return ret
353         return False
354
355     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
356         """This is called just after deleted was deleted from the tree"""
357         pass
358
359     def _delete(
360         self, value: ComparableNodeValue, parent: Optional[Node], node: Node
361     ) -> bool:
362         """Delete helper"""
363         if node.value == value:
364
365             # Deleting a leaf node
366             if node.left is None and node.right is None:
367                 if parent is not None:
368                     if parent.left == node:
369                         parent.left = None
370                     else:
371                         assert parent.right == node
372                         parent.right = None
373                 self._on_delete(parent, node)
374                 return True
375
376             # Node only has a right.
377             elif node.left is None:
378                 assert node.right is not None
379                 if parent is not None:
380                     if parent.left == node:
381                         parent.left = node.right
382                     else:
383                         assert parent.right == node
384                         parent.right = node.right
385                 self._on_delete(parent, node)
386                 return True
387
388             # Node only has a left.
389             elif node.right is None:
390                 assert node.left is not None
391                 if parent is not None:
392                     if parent.left == node:
393                         parent.left = node.left
394                     else:
395                         assert parent.right == node
396                         parent.right = node.left
397                 self._on_delete(parent, node)
398                 return True
399
400             # Node has both a left and right.
401             else:
402                 assert node.left is not None and node.right is not None
403                 descendent = node.right
404                 while descendent.left is not None:
405                     descendent = descendent.left
406                 node.value = descendent.value
407                 return self._delete(node.value, node, node.right)
408         elif value < node.value and node.left is not None:
409             return self._delete(value, node, node.left)
410         elif value > node.value and node.right is not None:
411             return self._delete(value, node, node.right)
412         return False
413
414     def __len__(self):
415         """
416         Returns:
417             The count of items in the tree.
418
419         >>> t = BinarySearchTree()
420         >>> len(t)
421         0
422         >>> t.insert(50)
423         >>> len(t)
424         1
425         >>> t.__delitem__(50)
426         True
427         >>> len(t)
428         0
429         >>> t.insert(75)
430         >>> t.insert(25)
431         >>> t.insert(66)
432         >>> t.insert(22)
433         >>> t.insert(13)
434         >>> t.insert(85)
435         >>> len(t)
436         6
437
438         """
439         return self.count
440
441     def __contains__(self, value: ComparableNodeValue) -> bool:
442         """
443         Returns:
444             True if the item is in the tree; False otherwise.
445         """
446         return self.__getitem__(value) is not None
447
448     def _iterate_preorder(self, node: Node):
449         yield node.value
450         if node.left is not None:
451             yield from self._iterate_preorder(node.left)
452         if node.right is not None:
453             yield from self._iterate_preorder(node.right)
454
455     def _iterate_inorder(self, node: Node):
456         if node.left is not None:
457             yield from self._iterate_inorder(node.left)
458         yield node.value
459         if node.right is not None:
460             yield from self._iterate_inorder(node.right)
461
462     def _iterate_postorder(self, node: Node):
463         if node.left is not None:
464             yield from self._iterate_postorder(node.left)
465         if node.right is not None:
466             yield from self._iterate_postorder(node.right)
467         yield node.value
468
469     def iterate_preorder(self):
470         """
471         Returns:
472             A Generator that yields the tree's items in a
473             preorder traversal sequence.
474
475         >>> t = BinarySearchTree()
476         >>> t.insert(50)
477         >>> t.insert(75)
478         >>> t.insert(25)
479         >>> t.insert(66)
480         >>> t.insert(22)
481         >>> t.insert(13)
482
483         >>> for value in t.iterate_preorder():
484         ...     print(value)
485         50
486         25
487         22
488         13
489         75
490         66
491
492         """
493         if self.root is not None:
494             yield from self._iterate_preorder(self.root)
495
496     def iterate_inorder(self):
497         """
498         Returns:
499             A Generator that yield the tree's items in a preorder
500             traversal sequence.
501
502         >>> t = BinarySearchTree()
503         >>> t.insert(50)
504         >>> t.insert(75)
505         >>> t.insert(25)
506         >>> t.insert(66)
507         >>> t.insert(22)
508         >>> t.insert(13)
509         >>> t.insert(24)
510         >>> t
511         50
512         ├──25
513         │  └──22
514         │     ├──13
515         │     └──24
516         └──75
517            └──66
518
519         >>> for value in t.iterate_inorder():
520         ...     print(value)
521         13
522         22
523         24
524         25
525         50
526         66
527         75
528
529         """
530         if self.root is not None:
531             yield from self._iterate_inorder(self.root)
532
533     def iterate_postorder(self):
534         """
535         Returns:
536             A Generator that yield the tree's items in a preorder
537             traversal sequence.
538
539         >>> t = BinarySearchTree()
540         >>> t.insert(50)
541         >>> t.insert(75)
542         >>> t.insert(25)
543         >>> t.insert(66)
544         >>> t.insert(22)
545         >>> t.insert(13)
546
547         >>> for value in t.iterate_postorder():
548         ...     print(value)
549         13
550         22
551         25
552         66
553         75
554         50
555
556         """
557         if self.root is not None:
558             yield from self._iterate_postorder(self.root)
559
560     def _iterate_leaves(self, node: Node):
561         if node.left is not None:
562             yield from self._iterate_leaves(node.left)
563         if node.right is not None:
564             yield from self._iterate_leaves(node.right)
565         if node.left is None and node.right is None:
566             yield node.value
567
568     def iterate_leaves(self):
569         """
570         Returns:
571             A Gemerator that yielde only the leaf nodes in the
572             tree.
573
574         >>> t = BinarySearchTree()
575         >>> t.insert(50)
576         >>> t.insert(75)
577         >>> t.insert(25)
578         >>> t.insert(66)
579         >>> t.insert(22)
580         >>> t.insert(13)
581
582         >>> for value in t.iterate_leaves():
583         ...     print(value)
584         13
585         66
586
587         """
588         if self.root is not None:
589             yield from self._iterate_leaves(self.root)
590
591     def _iterate_by_depth(self, node: Node, depth: int):
592         if depth == 0:
593             yield node.value
594         else:
595             assert depth > 0
596             if node.left is not None:
597                 yield from self._iterate_by_depth(node.left, depth - 1)
598             if node.right is not None:
599                 yield from self._iterate_by_depth(node.right, depth - 1)
600
601     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
602         """
603         Args:
604             depth: the desired depth
605
606         Returns:
607             A Generator that yields nodes at the prescribed depth in
608             the tree.
609
610         >>> t = BinarySearchTree()
611         >>> t.insert(50)
612         >>> t.insert(75)
613         >>> t.insert(25)
614         >>> t.insert(66)
615         >>> t.insert(22)
616         >>> t.insert(13)
617
618         >>> for value in t.iterate_nodes_by_depth(2):
619         ...     print(value)
620         22
621         66
622
623         >>> for value in t.iterate_nodes_by_depth(3):
624         ...     print(value)
625         13
626
627         """
628         if self.root is not None:
629             yield from self._iterate_by_depth(self.root, depth)
630
631     def get_next_node(self, node: Node) -> Optional[Node]:
632         """
633         Args:
634             node: the node whose next greater successor is desired
635
636         Returns:
637             Given a tree node, returns the next greater node in the tree.
638             If the given node is the greatest node in the tree, returns None.
639
640         >>> t = BinarySearchTree()
641         >>> t.insert(50)
642         >>> t.insert(75)
643         >>> t.insert(25)
644         >>> t.insert(66)
645         >>> t.insert(22)
646         >>> t.insert(13)
647         >>> t.insert(23)
648         >>> t
649         50
650         ├──25
651         │  └──22
652         │     ├──13
653         │     └──23
654         └──75
655            └──66
656
657         >>> n = t[23]
658         >>> t.get_next_node(n).value
659         25
660
661         >>> n = t[50]
662         >>> t.get_next_node(n).value
663         66
664
665         >>> n = t[75]
666         >>> t.get_next_node(n) is None
667         True
668
669         """
670         if node.right is not None:
671             x = node.right
672             while x.left is not None:
673                 x = x.left
674             return x
675
676         path = self.parent_path(node)
677         assert path[-1] is not None
678         assert path[-1] == node
679         path = path[:-1]
680         path.reverse()
681         for ancestor in path:
682             assert ancestor is not None
683             if node != ancestor.right:
684                 return ancestor
685             node = ancestor
686         return None
687
688     def get_nodes_in_range_inclusive(
689         self, lower: ComparableNodeValue, upper: ComparableNodeValue
690     ):
691         """
692         >>> t = BinarySearchTree()
693         >>> t.insert(50)
694         >>> t.insert(75)
695         >>> t.insert(25)
696         >>> t.insert(66)
697         >>> t.insert(22)
698         >>> t.insert(13)
699         >>> t.insert(23)
700         >>> t
701         50
702         ├──25
703         │  └──22
704         │     ├──13
705         │     └──23
706         └──75
707            └──66
708
709         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
710         ...     print(node.value)
711         22
712         23
713         25
714         50
715         66
716         """
717         node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
718             lower, self.root
719         )
720         while node:
721             if lower <= node.value <= upper:
722                 yield node
723             node = self.get_next_node(node)
724
725     def _depth(self, node: Node, sofar: int) -> int:
726         depth_left = sofar + 1
727         depth_right = sofar + 1
728         if node.left is not None:
729             depth_left = self._depth(node.left, sofar + 1)
730         if node.right is not None:
731             depth_right = self._depth(node.right, sofar + 1)
732         return max(depth_left, depth_right)
733
734     def depth(self) -> int:
735         """
736         Returns:
737             The max height (depth) of the tree in plies (edge distance
738             from root).
739
740         >>> t = BinarySearchTree()
741         >>> t.depth()
742         0
743
744         >>> t.insert(50)
745         >>> t.depth()
746         1
747
748         >>> t.insert(65)
749         >>> t.depth()
750         2
751
752         >>> t.insert(33)
753         >>> t.depth()
754         2
755
756         >>> t.insert(2)
757         >>> t.insert(1)
758         >>> t.depth()
759         4
760
761         """
762         if self.root is None:
763             return 0
764         return self._depth(self.root, 0)
765
766     def height(self) -> int:
767         """Returns the height (i.e. max depth) of the tree"""
768         return self.depth()
769
770     def repr_traverse(
771         self,
772         padding: str,
773         pointer: str,
774         node: Optional[Node],
775         has_right_sibling: bool,
776     ) -> str:
777         if node is not None:
778             viz = f"\n{padding}{pointer}{node.value}"
779             if has_right_sibling:
780                 padding += "│  "
781             else:
782                 padding += "   "
783
784             pointer_right = "└──"
785             if node.right is not None:
786                 pointer_left = "├──"
787             else:
788                 pointer_left = "└──"
789
790             viz += self.repr_traverse(
791                 padding, pointer_left, node.left, node.right is not None
792             )
793             viz += self.repr_traverse(padding, pointer_right, node.right, False)
794             return viz
795         return ""
796
797     def __repr__(self):
798         """
799         Returns:
800             An ASCII string representation of the tree.
801
802         >>> t = BinarySearchTree()
803         >>> t.insert(50)
804         >>> t.insert(25)
805         >>> t.insert(75)
806         >>> t.insert(12)
807         >>> t.insert(33)
808         >>> t.insert(88)
809         >>> t.insert(55)
810         >>> t
811         50
812         ├──25
813         │  ├──12
814         │  └──33
815         └──75
816            ├──55
817            └──88
818         """
819         if self.root is None:
820             return ""
821
822         ret = f"{self.root.value}"
823         pointer_right = "└──"
824         if self.root.right is None:
825             pointer_left = "└──"
826         else:
827             pointer_left = "├──"
828
829         ret += self.repr_traverse(
830             "", pointer_left, self.root.left, self.root.left is not None
831         )
832         ret += self.repr_traverse("", pointer_right, self.root.right, False)
833         return ret
834
835
836 if __name__ == "__main__":
837     import doctest
838
839     doctest.testmod()