a4316a52a5402acec65328ca6d9f3610bbff62e3
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         A BST node.  Note that value can be anything as long as it
14         is comparable.  Check out :meth:`functools.total_ordering`
15         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
16
17         Args:
18             value: a reference to the value of the node.
19         """
20         self.left: Optional[Node] = None
21         self.right: Optional[Node] = None
22         self.value = value
23
24
25 class BinarySearchTree(object):
26     def __init__(self):
27         self.root = None
28         self.count = 0
29         self.traverse = None
30
31     def get_root(self) -> Optional[Node]:
32         """
33         Returns:
34             The root of the BST
35         """
36
37         return self.root
38
39     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
40         """This is called immediately _after_ a new node is inserted."""
41         pass
42
43     def insert(self, value: Any) -> None:
44         """
45         Insert something into the tree.
46
47         Args:
48             value: the value to be inserted.
49
50         >>> t = BinarySearchTree()
51         >>> t.insert(10)
52         >>> t.insert(20)
53         >>> t.insert(5)
54         >>> len(t)
55         3
56
57         >>> t.get_root().value
58         10
59
60         """
61         if self.root is None:
62             self.root = Node(value)
63             self.count = 1
64             self._on_insert(None, self.root)
65         else:
66             self._insert(value, self.root)
67
68     def _insert(self, value: Any, node: Node):
69         """Insertion helper"""
70         if value < node.value:
71             if node.left is not None:
72                 self._insert(value, node.left)
73             else:
74                 node.left = Node(value)
75                 self.count += 1
76                 self._on_insert(node, node.left)
77         else:
78             if node.right is not None:
79                 self._insert(value, node.right)
80             else:
81                 node.right = Node(value)
82                 self.count += 1
83                 self._on_insert(node, node.right)
84
85     def __getitem__(self, value: Any) -> Optional[Node]:
86         """
87         Find an item in the tree and return its Node.  Returns
88         None if the item is not in the tree.
89
90         >>> t = BinarySearchTree()
91         >>> t[99]
92
93         >>> t.insert(10)
94         >>> t.insert(20)
95         >>> t.insert(5)
96         >>> t[10].value
97         10
98
99         >>> t[99]
100
101         """
102         if self.root is not None:
103             return self._find(value, self.root)
104         return None
105
106     def _find(self, value: Any, node: Node) -> Optional[Node]:
107         """Find helper"""
108         if value == node.value:
109             return node
110         elif value < node.value and node.left is not None:
111             return self._find(value, node.left)
112         elif value > node.value and node.right is not None:
113             return self._find(value, node.right)
114         return None
115
116     def _find_fuzzy(self, value: Any, node: Node) -> Node:
117         """Find helper that returns the closest node <= the target value."""
118         if value == node.value:
119             return node
120         below = None
121         if value < node.value and node.left is not None:
122             below = self._find_fuzzy(value, node.left)
123         elif value > node.value and node.right is not None:
124             below = self._find_fuzzy(value, node.right)
125         if not below:
126             return node
127         return below
128
129     def _parent_path(
130         self, current: Optional[Node], target: Node
131     ) -> List[Optional[Node]]:
132         """Internal helper"""
133         if current is None:
134             return [None]
135         ret: List[Optional[Node]] = [current]
136         if target.value == current.value:
137             return ret
138         elif target.value < current.value:
139             ret.extend(self._parent_path(current.left, target))
140             return ret
141         else:
142             assert target.value > current.value
143             ret.extend(self._parent_path(current.right, target))
144             return ret
145
146     def parent_path(self, node: Node) -> List[Optional[Node]]:
147         """Get a node's parent path.
148
149         Args:
150             node: the node to check
151
152         Returns:
153             a list of nodes representing the path from
154             the tree's root to the node.
155
156         .. note::
157
158             If the node does not exist in the tree, the last element
159             on the path will be None but the path will indicate the
160             ancestor path of that node were it to be inserted.
161
162         >>> t = BinarySearchTree()
163         >>> t.insert(50)
164         >>> t.insert(75)
165         >>> t.insert(25)
166         >>> t.insert(12)
167         >>> t.insert(33)
168         >>> t.insert(4)
169         >>> t.insert(88)
170         >>> t
171         50
172         ├──25
173         │  ├──12
174         │  │  └──4
175         │  └──33
176         └──75
177            └──88
178
179         >>> n = t[4]
180         >>> for x in t.parent_path(n):
181         ...     print(x.value)
182         50
183         25
184         12
185         4
186
187         >>> del t[4]
188         >>> for x in t.parent_path(n):
189         ...     if x is not None:
190         ...         print(x.value)
191         ...     else:
192         ...         print(x)
193         50
194         25
195         12
196         None
197
198         """
199         return self._parent_path(self.root, node)
200
201     def __delitem__(self, value: Any) -> bool:
202         """
203         Delete an item from the tree and preserve the BST property.
204
205         Args:
206             value: the value of the node to be deleted.
207
208         Returns:
209             True if the value was found and its associated node was
210             successfully deleted and False otherwise.
211
212         >>> t = BinarySearchTree()
213         >>> t.insert(50)
214         >>> t.insert(75)
215         >>> t.insert(25)
216         >>> t.insert(66)
217         >>> t.insert(22)
218         >>> t.insert(13)
219         >>> t.insert(85)
220         >>> t
221         50
222         ├──25
223         │  └──22
224         │     └──13
225         └──75
226            ├──66
227            └──85
228
229         >>> for value in t.iterate_inorder():
230         ...     print(value)
231         13
232         22
233         25
234         50
235         66
236         75
237         85
238
239         >>> del t[22]  # Note: bool result is discarded
240
241         >>> for value in t.iterate_inorder():
242         ...     print(value)
243         13
244         25
245         50
246         66
247         75
248         85
249
250         >>> t.__delitem__(13)
251         True
252         >>> for value in t.iterate_inorder():
253         ...     print(value)
254         25
255         50
256         66
257         75
258         85
259
260         >>> t.__delitem__(75)
261         True
262         >>> for value in t.iterate_inorder():
263         ...     print(value)
264         25
265         50
266         66
267         85
268         >>> t
269         50
270         ├──25
271         └──85
272            └──66
273
274         >>> t.__delitem__(99)
275         False
276
277         """
278         if self.root is not None:
279             ret = self._delete(value, None, self.root)
280             if ret:
281                 self.count -= 1
282                 if self.count == 0:
283                     self.root = None
284             return ret
285         return False
286
287     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
288         """This is called just after deleted was deleted from the tree"""
289         pass
290
291     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
292         """Delete helper"""
293         if node.value == value:
294
295             # Deleting a leaf node
296             if node.left is None and node.right is None:
297                 if parent is not None:
298                     if parent.left == node:
299                         parent.left = None
300                     else:
301                         assert parent.right == node
302                         parent.right = None
303                 self._on_delete(parent, node)
304                 return True
305
306             # Node only has a right.
307             elif node.left is None:
308                 assert node.right is not None
309                 if parent is not None:
310                     if parent.left == node:
311                         parent.left = node.right
312                     else:
313                         assert parent.right == node
314                         parent.right = node.right
315                 self._on_delete(parent, node)
316                 return True
317
318             # Node only has a left.
319             elif node.right is None:
320                 assert node.left is not None
321                 if parent is not None:
322                     if parent.left == node:
323                         parent.left = node.left
324                     else:
325                         assert parent.right == node
326                         parent.right = node.left
327                 self._on_delete(parent, node)
328                 return True
329
330             # Node has both a left and right.
331             else:
332                 assert node.left is not None and node.right is not None
333                 descendent = node.right
334                 while descendent.left is not None:
335                     descendent = descendent.left
336                 node.value = descendent.value
337                 return self._delete(node.value, node, node.right)
338         elif value < node.value and node.left is not None:
339             return self._delete(value, node, node.left)
340         elif value > node.value and node.right is not None:
341             return self._delete(value, node, node.right)
342         return False
343
344     def __len__(self):
345         """
346         Returns:
347             The count of items in the tree.
348
349         >>> t = BinarySearchTree()
350         >>> len(t)
351         0
352         >>> t.insert(50)
353         >>> len(t)
354         1
355         >>> t.__delitem__(50)
356         True
357         >>> len(t)
358         0
359         >>> t.insert(75)
360         >>> t.insert(25)
361         >>> t.insert(66)
362         >>> t.insert(22)
363         >>> t.insert(13)
364         >>> t.insert(85)
365         >>> len(t)
366         6
367
368         """
369         return self.count
370
371     def __contains__(self, value: Any) -> bool:
372         """
373         Returns:
374             True if the item is in the tree; False otherwise.
375         """
376         return self.__getitem__(value) is not None
377
378     def _iterate_preorder(self, node: Node):
379         yield node.value
380         if node.left is not None:
381             yield from self._iterate_preorder(node.left)
382         if node.right is not None:
383             yield from self._iterate_preorder(node.right)
384
385     def _iterate_inorder(self, node: Node):
386         if node.left is not None:
387             yield from self._iterate_inorder(node.left)
388         yield node.value
389         if node.right is not None:
390             yield from self._iterate_inorder(node.right)
391
392     def _iterate_postorder(self, node: Node):
393         if node.left is not None:
394             yield from self._iterate_postorder(node.left)
395         if node.right is not None:
396             yield from self._iterate_postorder(node.right)
397         yield node.value
398
399     def iterate_preorder(self):
400         """
401         Returns:
402             A Generator that yields the tree's items in a
403             preorder traversal sequence.
404
405         >>> t = BinarySearchTree()
406         >>> t.insert(50)
407         >>> t.insert(75)
408         >>> t.insert(25)
409         >>> t.insert(66)
410         >>> t.insert(22)
411         >>> t.insert(13)
412
413         >>> for value in t.iterate_preorder():
414         ...     print(value)
415         50
416         25
417         22
418         13
419         75
420         66
421
422         """
423         if self.root is not None:
424             yield from self._iterate_preorder(self.root)
425
426     def iterate_inorder(self):
427         """
428         Returns:
429             A Generator that yield the tree's items in a preorder
430             traversal sequence.
431
432         >>> t = BinarySearchTree()
433         >>> t.insert(50)
434         >>> t.insert(75)
435         >>> t.insert(25)
436         >>> t.insert(66)
437         >>> t.insert(22)
438         >>> t.insert(13)
439         >>> t.insert(24)
440         >>> t
441         50
442         ├──25
443         │  └──22
444         │     ├──13
445         │     └──24
446         └──75
447            └──66
448
449         >>> for value in t.iterate_inorder():
450         ...     print(value)
451         13
452         22
453         24
454         25
455         50
456         66
457         75
458
459         """
460         if self.root is not None:
461             yield from self._iterate_inorder(self.root)
462
463     def iterate_postorder(self):
464         """
465         Returns:
466             A Generator that yield the tree's items in a preorder
467             traversal sequence.
468
469         >>> t = BinarySearchTree()
470         >>> t.insert(50)
471         >>> t.insert(75)
472         >>> t.insert(25)
473         >>> t.insert(66)
474         >>> t.insert(22)
475         >>> t.insert(13)
476
477         >>> for value in t.iterate_postorder():
478         ...     print(value)
479         13
480         22
481         25
482         66
483         75
484         50
485
486         """
487         if self.root is not None:
488             yield from self._iterate_postorder(self.root)
489
490     def _iterate_leaves(self, node: Node):
491         if node.left is not None:
492             yield from self._iterate_leaves(node.left)
493         if node.right is not None:
494             yield from self._iterate_leaves(node.right)
495         if node.left is None and node.right is None:
496             yield node.value
497
498     def iterate_leaves(self):
499         """
500         Returns:
501             A Gemerator that yielde only the leaf nodes in the
502             tree.
503
504         >>> t = BinarySearchTree()
505         >>> t.insert(50)
506         >>> t.insert(75)
507         >>> t.insert(25)
508         >>> t.insert(66)
509         >>> t.insert(22)
510         >>> t.insert(13)
511
512         >>> for value in t.iterate_leaves():
513         ...     print(value)
514         13
515         66
516
517         """
518         if self.root is not None:
519             yield from self._iterate_leaves(self.root)
520
521     def _iterate_by_depth(self, node: Node, depth: int):
522         if depth == 0:
523             yield node.value
524         else:
525             assert depth > 0
526             if node.left is not None:
527                 yield from self._iterate_by_depth(node.left, depth - 1)
528             if node.right is not None:
529                 yield from self._iterate_by_depth(node.right, depth - 1)
530
531     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
532         """
533         Args:
534             depth: the desired depth
535
536         Returns:
537             A Generator that yields nodes at the prescribed depth in
538             the tree.
539
540         >>> t = BinarySearchTree()
541         >>> t.insert(50)
542         >>> t.insert(75)
543         >>> t.insert(25)
544         >>> t.insert(66)
545         >>> t.insert(22)
546         >>> t.insert(13)
547
548         >>> for value in t.iterate_nodes_by_depth(2):
549         ...     print(value)
550         22
551         66
552
553         >>> for value in t.iterate_nodes_by_depth(3):
554         ...     print(value)
555         13
556
557         """
558         if self.root is not None:
559             yield from self._iterate_by_depth(self.root, depth)
560
561     def get_next_node(self, node: Node) -> Optional[Node]:
562         """
563         Args:
564             node: the node whose next greater successor is desired
565
566         Returns:
567             Given a tree node, returns the next greater node in the tree.
568             If the given node is the greatest node in the tree, returns None.
569
570         >>> t = BinarySearchTree()
571         >>> t.insert(50)
572         >>> t.insert(75)
573         >>> t.insert(25)
574         >>> t.insert(66)
575         >>> t.insert(22)
576         >>> t.insert(13)
577         >>> t.insert(23)
578         >>> t
579         50
580         ├──25
581         │  └──22
582         │     ├──13
583         │     └──23
584         └──75
585            └──66
586
587         >>> n = t[23]
588         >>> t.get_next_node(n).value
589         25
590
591         >>> n = t[50]
592         >>> t.get_next_node(n).value
593         66
594
595         >>> n = t[75]
596         >>> t.get_next_node(n) is None
597         True
598
599         """
600         if node.right is not None:
601             x = node.right
602             while x.left is not None:
603                 x = x.left
604             return x
605
606         path = self.parent_path(node)
607         assert path[-1] is not None
608         assert path[-1] == node
609         path = path[:-1]
610         path.reverse()
611         for ancestor in path:
612             assert ancestor is not None
613             if node != ancestor.right:
614                 return ancestor
615             node = ancestor
616         return None
617
618     def get_nodes_in_range_inclusive(self, lower: Any, upper: Any):
619         """
620         >>> t = BinarySearchTree()
621         >>> t.insert(50)
622         >>> t.insert(75)
623         >>> t.insert(25)
624         >>> t.insert(66)
625         >>> t.insert(22)
626         >>> t.insert(13)
627         >>> t.insert(23)
628         >>> t
629         50
630         ├──25
631         │  └──22
632         │     ├──13
633         │     └──23
634         └──75
635            └──66
636
637         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
638         ...     print(node.value)
639         22
640         23
641         25
642         50
643         66
644         """
645         node: Optional[Node] = self._find_fuzzy(lower, self.root)
646         while node:
647             if lower <= node.value <= upper:
648                 yield node
649             node = self.get_next_node(node)
650
651     def _depth(self, node: Node, sofar: int) -> int:
652         depth_left = sofar + 1
653         depth_right = sofar + 1
654         if node.left is not None:
655             depth_left = self._depth(node.left, sofar + 1)
656         if node.right is not None:
657             depth_right = self._depth(node.right, sofar + 1)
658         return max(depth_left, depth_right)
659
660     def depth(self) -> int:
661         """
662         Returns:
663             The max height (depth) of the tree in plies (edge distance
664             from root).
665
666         >>> t = BinarySearchTree()
667         >>> t.depth()
668         0
669
670         >>> t.insert(50)
671         >>> t.depth()
672         1
673
674         >>> t.insert(65)
675         >>> t.depth()
676         2
677
678         >>> t.insert(33)
679         >>> t.depth()
680         2
681
682         >>> t.insert(2)
683         >>> t.insert(1)
684         >>> t.depth()
685         4
686
687         """
688         if self.root is None:
689             return 0
690         return self._depth(self.root, 0)
691
692     def height(self) -> int:
693         """Returns the height (i.e. max depth) of the tree"""
694         return self.depth()
695
696     def repr_traverse(
697         self,
698         padding: str,
699         pointer: str,
700         node: Optional[Node],
701         has_right_sibling: bool,
702     ) -> str:
703         if node is not None:
704             viz = f"\n{padding}{pointer}{node.value}"
705             if has_right_sibling:
706                 padding += "│  "
707             else:
708                 padding += "   "
709
710             pointer_right = "└──"
711             if node.right is not None:
712                 pointer_left = "├──"
713             else:
714                 pointer_left = "└──"
715
716             viz += self.repr_traverse(
717                 padding, pointer_left, node.left, node.right is not None
718             )
719             viz += self.repr_traverse(padding, pointer_right, node.right, False)
720             return viz
721         return ""
722
723     def __repr__(self):
724         """
725         Returns:
726             An ASCII string representation of the tree.
727
728         >>> t = BinarySearchTree()
729         >>> t.insert(50)
730         >>> t.insert(25)
731         >>> t.insert(75)
732         >>> t.insert(12)
733         >>> t.insert(33)
734         >>> t.insert(88)
735         >>> t.insert(55)
736         >>> t
737         50
738         ├──25
739         │  ├──12
740         │  └──33
741         └──75
742            ├──55
743            └──88
744         """
745         if self.root is None:
746             return ""
747
748         ret = f"{self.root.value}"
749         pointer_right = "└──"
750         if self.root.right is None:
751             pointer_left = "└──"
752         else:
753             pointer_left = "├──"
754
755         ret += self.repr_traverse(
756             "", pointer_left, self.root.left, self.root.left is not None
757         )
758         ret += self.repr_traverse("", pointer_right, self.root.right, False)
759         return ret
760
761
762 if __name__ == "__main__":
763     import doctest
764
765     doctest.testmod()