Better definition of "fuzzy".
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2023, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         A BST node.  Note that value can be anything as long as it
14         is comparable.  Check out :meth:`functools.total_ordering`
15         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
16
17         Args:
18             value: a reference to the value of the node.
19         """
20         self.left: Optional[Node] = None
21         self.right: Optional[Node] = None
22         self.value = value
23
24
25 class BinarySearchTree(object):
26     def __init__(self):
27         self.root = None
28         self.count = 0
29         self.traverse = None
30
31     def get_root(self) -> Optional[Node]:
32         """
33         Returns:
34             The root of the BST
35         """
36
37         return self.root
38
39     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
40         """This is called immediately _after_ a new node is inserted."""
41         pass
42
43     def insert(self, value: Any) -> None:
44         """
45         Insert something into the tree.
46
47         Args:
48             value: the value to be inserted.
49
50         >>> t = BinarySearchTree()
51         >>> t.insert(10)
52         >>> t.insert(20)
53         >>> t.insert(5)
54         >>> len(t)
55         3
56
57         >>> t.get_root().value
58         10
59
60         """
61         if self.root is None:
62             self.root = Node(value)
63             self.count = 1
64             self._on_insert(None, self.root)
65         else:
66             self._insert(value, self.root)
67
68     def _insert(self, value: Any, node: Node):
69         """Insertion helper"""
70         if value < node.value:
71             if node.left is not None:
72                 self._insert(value, node.left)
73             else:
74                 node.left = Node(value)
75                 self.count += 1
76                 self._on_insert(node, node.left)
77         else:
78             if node.right is not None:
79                 self._insert(value, node.right)
80             else:
81                 node.right = Node(value)
82                 self.count += 1
83                 self._on_insert(node, node.right)
84
85     def __getitem__(self, value: Any) -> Optional[Node]:
86         """
87         Find an item in the tree and return its Node.  Returns
88         None if the item is not in the tree.
89
90         >>> t = BinarySearchTree()
91         >>> t[99]
92
93         >>> t.insert(10)
94         >>> t.insert(20)
95         >>> t.insert(5)
96         >>> t[10].value
97         10
98
99         >>> t[99]
100
101         """
102         if self.root is not None:
103             return self._find(value, self.root)
104         return None
105
106     def _find(self, value: Any, node: Node) -> Optional[Node]:
107         """Find helper"""
108         if value == node.value:
109             return node
110         elif value < node.value and node.left is not None:
111             return self._find(value, node.left)
112         elif value > node.value and node.right is not None:
113             return self._find(value, node.right)
114         return None
115
116     def _find_lowest_value_greater_than_or_equal_to(
117         self, target: Any, node: Optional[Node]
118     ) -> Optional[Node]:
119         """Find helper that returns the lowest node that is greater
120         than or equal to the target value.  Returns None if target is
121         greater than the highest node in the tree.
122
123         >>> t = BinarySearchTree()
124         >>> t.insert(50)
125         >>> t.insert(75)
126         >>> t.insert(25)
127         >>> t.insert(66)
128         >>> t.insert(22)
129         >>> t.insert(13)
130         >>> t.insert(85)
131         >>> t
132         50
133         ├──25
134         │  └──22
135         │     └──13
136         └──75
137            ├──66
138            └──85
139
140         >>> t._find_lowest_value_greater_than_or_equal_to(48, t.root).value
141         50
142         >>> t._find_lowest_value_greater_than_or_equal_to(55, t.root).value
143         66
144         >>> t._find_lowest_value_greater_than_or_equal_to(1, t.root).value
145         13
146         >>> t._find_lowest_value_greater_than_or_equal_to(24, t.root).value
147         25
148         >>> t._find_lowest_value_greater_than_or_equal_to(20, t.root).value
149         22
150         >>> t._find_lowest_value_greater_than_or_equal_to(72, t.root).value
151         75
152         >>> t._find_lowest_value_greater_than_or_equal_to(78, t.root).value
153         85
154         >>> t._find_lowest_value_greater_than_or_equal_to(95, t.root) is None
155         True
156
157         """
158
159         if not node:
160             return None
161
162         if target == node.value:
163             return node
164         elif target >= node.value:
165             return self._find_lowest_value_greater_than_or_equal_to(target, node.right)
166         else:
167             assert target < node.value
168             if below := self._find_lowest_value_greater_than_or_equal_to(
169                 target, node.left
170             ):
171                 return below
172             else:
173                 return node
174
175     def _parent_path(
176         self, current: Optional[Node], target: Node
177     ) -> List[Optional[Node]]:
178         """Internal helper"""
179         if current is None:
180             return [None]
181         ret: List[Optional[Node]] = [current]
182         if target.value == current.value:
183             return ret
184         elif target.value < current.value:
185             ret.extend(self._parent_path(current.left, target))
186             return ret
187         else:
188             assert target.value > current.value
189             ret.extend(self._parent_path(current.right, target))
190             return ret
191
192     def parent_path(self, node: Node) -> List[Optional[Node]]:
193         """Get a node's parent path.
194
195         Args:
196             node: the node to check
197
198         Returns:
199             a list of nodes representing the path from
200             the tree's root to the node.
201
202         .. note::
203
204             If the node does not exist in the tree, the last element
205             on the path will be None but the path will indicate the
206             ancestor path of that node were it to be inserted.
207
208         >>> t = BinarySearchTree()
209         >>> t.insert(50)
210         >>> t.insert(75)
211         >>> t.insert(25)
212         >>> t.insert(12)
213         >>> t.insert(33)
214         >>> t.insert(4)
215         >>> t.insert(88)
216         >>> t
217         50
218         ├──25
219         │  ├──12
220         │  │  └──4
221         │  └──33
222         └──75
223            └──88
224
225         >>> n = t[4]
226         >>> for x in t.parent_path(n):
227         ...     print(x.value)
228         50
229         25
230         12
231         4
232
233         >>> del t[4]
234         >>> for x in t.parent_path(n):
235         ...     if x is not None:
236         ...         print(x.value)
237         ...     else:
238         ...         print(x)
239         50
240         25
241         12
242         None
243
244         """
245         return self._parent_path(self.root, node)
246
247     def __delitem__(self, value: Any) -> bool:
248         """
249         Delete an item from the tree and preserve the BST property.
250
251         Args:
252             value: the value of the node to be deleted.
253
254         Returns:
255             True if the value was found and its associated node was
256             successfully deleted and False otherwise.
257
258         >>> t = BinarySearchTree()
259         >>> t.insert(50)
260         >>> t.insert(75)
261         >>> t.insert(25)
262         >>> t.insert(66)
263         >>> t.insert(22)
264         >>> t.insert(13)
265         >>> t.insert(85)
266         >>> t
267         50
268         ├──25
269         │  └──22
270         │     └──13
271         └──75
272            ├──66
273            └──85
274
275         >>> for value in t.iterate_inorder():
276         ...     print(value)
277         13
278         22
279         25
280         50
281         66
282         75
283         85
284
285         >>> del t[22]  # Note: bool result is discarded
286
287         >>> for value in t.iterate_inorder():
288         ...     print(value)
289         13
290         25
291         50
292         66
293         75
294         85
295
296         >>> t.__delitem__(13)
297         True
298         >>> for value in t.iterate_inorder():
299         ...     print(value)
300         25
301         50
302         66
303         75
304         85
305
306         >>> t.__delitem__(75)
307         True
308         >>> for value in t.iterate_inorder():
309         ...     print(value)
310         25
311         50
312         66
313         85
314         >>> t
315         50
316         ├──25
317         └──85
318            └──66
319
320         >>> t.__delitem__(99)
321         False
322
323         """
324         if self.root is not None:
325             ret = self._delete(value, None, self.root)
326             if ret:
327                 self.count -= 1
328                 if self.count == 0:
329                     self.root = None
330             return ret
331         return False
332
333     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
334         """This is called just after deleted was deleted from the tree"""
335         pass
336
337     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
338         """Delete helper"""
339         if node.value == value:
340
341             # Deleting a leaf node
342             if node.left is None and node.right is None:
343                 if parent is not None:
344                     if parent.left == node:
345                         parent.left = None
346                     else:
347                         assert parent.right == node
348                         parent.right = None
349                 self._on_delete(parent, node)
350                 return True
351
352             # Node only has a right.
353             elif node.left is None:
354                 assert node.right is not None
355                 if parent is not None:
356                     if parent.left == node:
357                         parent.left = node.right
358                     else:
359                         assert parent.right == node
360                         parent.right = node.right
361                 self._on_delete(parent, node)
362                 return True
363
364             # Node only has a left.
365             elif node.right is None:
366                 assert node.left is not None
367                 if parent is not None:
368                     if parent.left == node:
369                         parent.left = node.left
370                     else:
371                         assert parent.right == node
372                         parent.right = node.left
373                 self._on_delete(parent, node)
374                 return True
375
376             # Node has both a left and right.
377             else:
378                 assert node.left is not None and node.right is not None
379                 descendent = node.right
380                 while descendent.left is not None:
381                     descendent = descendent.left
382                 node.value = descendent.value
383                 return self._delete(node.value, node, node.right)
384         elif value < node.value and node.left is not None:
385             return self._delete(value, node, node.left)
386         elif value > node.value and node.right is not None:
387             return self._delete(value, node, node.right)
388         return False
389
390     def __len__(self):
391         """
392         Returns:
393             The count of items in the tree.
394
395         >>> t = BinarySearchTree()
396         >>> len(t)
397         0
398         >>> t.insert(50)
399         >>> len(t)
400         1
401         >>> t.__delitem__(50)
402         True
403         >>> len(t)
404         0
405         >>> t.insert(75)
406         >>> t.insert(25)
407         >>> t.insert(66)
408         >>> t.insert(22)
409         >>> t.insert(13)
410         >>> t.insert(85)
411         >>> len(t)
412         6
413
414         """
415         return self.count
416
417     def __contains__(self, value: Any) -> bool:
418         """
419         Returns:
420             True if the item is in the tree; False otherwise.
421         """
422         return self.__getitem__(value) is not None
423
424     def _iterate_preorder(self, node: Node):
425         yield node.value
426         if node.left is not None:
427             yield from self._iterate_preorder(node.left)
428         if node.right is not None:
429             yield from self._iterate_preorder(node.right)
430
431     def _iterate_inorder(self, node: Node):
432         if node.left is not None:
433             yield from self._iterate_inorder(node.left)
434         yield node.value
435         if node.right is not None:
436             yield from self._iterate_inorder(node.right)
437
438     def _iterate_postorder(self, node: Node):
439         if node.left is not None:
440             yield from self._iterate_postorder(node.left)
441         if node.right is not None:
442             yield from self._iterate_postorder(node.right)
443         yield node.value
444
445     def iterate_preorder(self):
446         """
447         Returns:
448             A Generator that yields the tree's items in a
449             preorder traversal sequence.
450
451         >>> t = BinarySearchTree()
452         >>> t.insert(50)
453         >>> t.insert(75)
454         >>> t.insert(25)
455         >>> t.insert(66)
456         >>> t.insert(22)
457         >>> t.insert(13)
458
459         >>> for value in t.iterate_preorder():
460         ...     print(value)
461         50
462         25
463         22
464         13
465         75
466         66
467
468         """
469         if self.root is not None:
470             yield from self._iterate_preorder(self.root)
471
472     def iterate_inorder(self):
473         """
474         Returns:
475             A Generator that yield the tree's items in a preorder
476             traversal sequence.
477
478         >>> t = BinarySearchTree()
479         >>> t.insert(50)
480         >>> t.insert(75)
481         >>> t.insert(25)
482         >>> t.insert(66)
483         >>> t.insert(22)
484         >>> t.insert(13)
485         >>> t.insert(24)
486         >>> t
487         50
488         ├──25
489         │  └──22
490         │     ├──13
491         │     └──24
492         └──75
493            └──66
494
495         >>> for value in t.iterate_inorder():
496         ...     print(value)
497         13
498         22
499         24
500         25
501         50
502         66
503         75
504
505         """
506         if self.root is not None:
507             yield from self._iterate_inorder(self.root)
508
509     def iterate_postorder(self):
510         """
511         Returns:
512             A Generator that yield the tree's items in a preorder
513             traversal sequence.
514
515         >>> t = BinarySearchTree()
516         >>> t.insert(50)
517         >>> t.insert(75)
518         >>> t.insert(25)
519         >>> t.insert(66)
520         >>> t.insert(22)
521         >>> t.insert(13)
522
523         >>> for value in t.iterate_postorder():
524         ...     print(value)
525         13
526         22
527         25
528         66
529         75
530         50
531
532         """
533         if self.root is not None:
534             yield from self._iterate_postorder(self.root)
535
536     def _iterate_leaves(self, node: Node):
537         if node.left is not None:
538             yield from self._iterate_leaves(node.left)
539         if node.right is not None:
540             yield from self._iterate_leaves(node.right)
541         if node.left is None and node.right is None:
542             yield node.value
543
544     def iterate_leaves(self):
545         """
546         Returns:
547             A Gemerator that yielde only the leaf nodes in the
548             tree.
549
550         >>> t = BinarySearchTree()
551         >>> t.insert(50)
552         >>> t.insert(75)
553         >>> t.insert(25)
554         >>> t.insert(66)
555         >>> t.insert(22)
556         >>> t.insert(13)
557
558         >>> for value in t.iterate_leaves():
559         ...     print(value)
560         13
561         66
562
563         """
564         if self.root is not None:
565             yield from self._iterate_leaves(self.root)
566
567     def _iterate_by_depth(self, node: Node, depth: int):
568         if depth == 0:
569             yield node.value
570         else:
571             assert depth > 0
572             if node.left is not None:
573                 yield from self._iterate_by_depth(node.left, depth - 1)
574             if node.right is not None:
575                 yield from self._iterate_by_depth(node.right, depth - 1)
576
577     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
578         """
579         Args:
580             depth: the desired depth
581
582         Returns:
583             A Generator that yields nodes at the prescribed depth in
584             the tree.
585
586         >>> t = BinarySearchTree()
587         >>> t.insert(50)
588         >>> t.insert(75)
589         >>> t.insert(25)
590         >>> t.insert(66)
591         >>> t.insert(22)
592         >>> t.insert(13)
593
594         >>> for value in t.iterate_nodes_by_depth(2):
595         ...     print(value)
596         22
597         66
598
599         >>> for value in t.iterate_nodes_by_depth(3):
600         ...     print(value)
601         13
602
603         """
604         if self.root is not None:
605             yield from self._iterate_by_depth(self.root, depth)
606
607     def get_next_node(self, node: Node) -> Optional[Node]:
608         """
609         Args:
610             node: the node whose next greater successor is desired
611
612         Returns:
613             Given a tree node, returns the next greater node in the tree.
614             If the given node is the greatest node in the tree, returns None.
615
616         >>> t = BinarySearchTree()
617         >>> t.insert(50)
618         >>> t.insert(75)
619         >>> t.insert(25)
620         >>> t.insert(66)
621         >>> t.insert(22)
622         >>> t.insert(13)
623         >>> t.insert(23)
624         >>> t
625         50
626         ├──25
627         │  └──22
628         │     ├──13
629         │     └──23
630         └──75
631            └──66
632
633         >>> n = t[23]
634         >>> t.get_next_node(n).value
635         25
636
637         >>> n = t[50]
638         >>> t.get_next_node(n).value
639         66
640
641         >>> n = t[75]
642         >>> t.get_next_node(n) is None
643         True
644
645         """
646         if node.right is not None:
647             x = node.right
648             while x.left is not None:
649                 x = x.left
650             return x
651
652         path = self.parent_path(node)
653         assert path[-1] is not None
654         assert path[-1] == node
655         path = path[:-1]
656         path.reverse()
657         for ancestor in path:
658             assert ancestor is not None
659             if node != ancestor.right:
660                 return ancestor
661             node = ancestor
662         return None
663
664     def get_nodes_in_range_inclusive(self, lower: Any, upper: Any):
665         """
666         >>> t = BinarySearchTree()
667         >>> t.insert(50)
668         >>> t.insert(75)
669         >>> t.insert(25)
670         >>> t.insert(66)
671         >>> t.insert(22)
672         >>> t.insert(13)
673         >>> t.insert(23)
674         >>> t
675         50
676         ├──25
677         │  └──22
678         │     ├──13
679         │     └──23
680         └──75
681            └──66
682
683         >>> for node in t.get_nodes_in_range_inclusive(21, 74):
684         ...     print(node.value)
685         22
686         23
687         25
688         50
689         66
690         """
691         node: Optional[Node] = self._find_lowest_value_greater_than_or_equal_to(
692             lower, self.root
693         )
694         while node:
695             if lower <= node.value <= upper:
696                 yield node
697             node = self.get_next_node(node)
698
699     def _depth(self, node: Node, sofar: int) -> int:
700         depth_left = sofar + 1
701         depth_right = sofar + 1
702         if node.left is not None:
703             depth_left = self._depth(node.left, sofar + 1)
704         if node.right is not None:
705             depth_right = self._depth(node.right, sofar + 1)
706         return max(depth_left, depth_right)
707
708     def depth(self) -> int:
709         """
710         Returns:
711             The max height (depth) of the tree in plies (edge distance
712             from root).
713
714         >>> t = BinarySearchTree()
715         >>> t.depth()
716         0
717
718         >>> t.insert(50)
719         >>> t.depth()
720         1
721
722         >>> t.insert(65)
723         >>> t.depth()
724         2
725
726         >>> t.insert(33)
727         >>> t.depth()
728         2
729
730         >>> t.insert(2)
731         >>> t.insert(1)
732         >>> t.depth()
733         4
734
735         """
736         if self.root is None:
737             return 0
738         return self._depth(self.root, 0)
739
740     def height(self) -> int:
741         """Returns the height (i.e. max depth) of the tree"""
742         return self.depth()
743
744     def repr_traverse(
745         self,
746         padding: str,
747         pointer: str,
748         node: Optional[Node],
749         has_right_sibling: bool,
750     ) -> str:
751         if node is not None:
752             viz = f"\n{padding}{pointer}{node.value}"
753             if has_right_sibling:
754                 padding += "│  "
755             else:
756                 padding += "   "
757
758             pointer_right = "└──"
759             if node.right is not None:
760                 pointer_left = "├──"
761             else:
762                 pointer_left = "└──"
763
764             viz += self.repr_traverse(
765                 padding, pointer_left, node.left, node.right is not None
766             )
767             viz += self.repr_traverse(padding, pointer_right, node.right, False)
768             return viz
769         return ""
770
771     def __repr__(self):
772         """
773         Returns:
774             An ASCII string representation of the tree.
775
776         >>> t = BinarySearchTree()
777         >>> t.insert(50)
778         >>> t.insert(25)
779         >>> t.insert(75)
780         >>> t.insert(12)
781         >>> t.insert(33)
782         >>> t.insert(88)
783         >>> t.insert(55)
784         >>> t
785         50
786         ├──25
787         │  ├──12
788         │  └──33
789         └──75
790            ├──55
791            └──88
792         """
793         if self.root is None:
794             return ""
795
796         ret = f"{self.root.value}"
797         pointer_right = "└──"
798         if self.root.right is None:
799             pointer_left = "└──"
800         else:
801             pointer_left = "├──"
802
803         ret += self.repr_traverse(
804             "", pointer_left, self.root.left, self.root.left is not None
805         )
806         ret += self.repr_traverse("", pointer_right, self.root.right, False)
807         return ret
808
809
810 if __name__ == "__main__":
811     import doctest
812
813     doctest.testmod()