Fix interval_tree so it actually works. Add unittests.
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         A BST node.  Note that value can be anything as long as it
14         is comparable.  Check out :meth:`functools.total_ordering`
15         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
16
17         Args:
18             value: a reference to the value of the node.
19         """
20         self.left: Optional[Node] = None
21         self.right: Optional[Node] = None
22         self.value = value
23
24
25 class BinarySearchTree(object):
26     def __init__(self):
27         self.root = None
28         self.count = 0
29         self.traverse = None
30
31     def get_root(self) -> Optional[Node]:
32         """
33         Returns:
34             The root of the BST
35         """
36
37         return self.root
38
39     def _on_insert(self, parent: Optional[Node], new: Node) -> None:
40         """This is called immediately _after_ a new node is inserted."""
41         pass
42
43     def insert(self, value: Any) -> None:
44         """
45         Insert something into the tree.
46
47         Args:
48             value: the value to be inserted.
49
50         >>> t = BinarySearchTree()
51         >>> t.insert(10)
52         >>> t.insert(20)
53         >>> t.insert(5)
54         >>> len(t)
55         3
56
57         >>> t.get_root().value
58         10
59
60         """
61         if self.root is None:
62             self.root = Node(value)
63             self.count = 1
64             self._on_insert(None, self.root)
65         else:
66             self._insert(value, self.root)
67
68     def _insert(self, value: Any, node: Node):
69         """Insertion helper"""
70         if value < node.value:
71             if node.left is not None:
72                 self._insert(value, node.left)
73             else:
74                 node.left = Node(value)
75                 self.count += 1
76                 self._on_insert(node, node.left)
77         else:
78             if node.right is not None:
79                 self._insert(value, node.right)
80             else:
81                 node.right = Node(value)
82                 self.count += 1
83                 self._on_insert(node, node.right)
84
85     def __getitem__(self, value: Any) -> Optional[Node]:
86         """
87         Find an item in the tree and return its Node.  Returns
88         None if the item is not in the tree.
89
90         >>> t = BinarySearchTree()
91         >>> t[99]
92
93         >>> t.insert(10)
94         >>> t.insert(20)
95         >>> t.insert(5)
96         >>> t[10].value
97         10
98
99         >>> t[99]
100
101         """
102         if self.root is not None:
103             return self._find(value, self.root)
104         return None
105
106     def _find(self, value: Any, node: Node) -> Optional[Node]:
107         """Find helper"""
108         if value == node.value:
109             return node
110         elif value < node.value and node.left is not None:
111             return self._find(value, node.left)
112         elif value > node.value and node.right is not None:
113             return self._find(value, node.right)
114         return None
115
116     def _parent_path(
117         self, current: Optional[Node], target: Node
118     ) -> List[Optional[Node]]:
119         """Internal helper"""
120         if current is None:
121             return [None]
122         ret: List[Optional[Node]] = [current]
123         if target.value == current.value:
124             return ret
125         elif target.value < current.value:
126             ret.extend(self._parent_path(current.left, target))
127             return ret
128         else:
129             assert target.value > current.value
130             ret.extend(self._parent_path(current.right, target))
131             return ret
132
133     def parent_path(self, node: Node) -> List[Optional[Node]]:
134         """Get a node's parent path.
135
136         Args:
137             node: the node to check
138
139         Returns:
140             a list of nodes representing the path from
141             the tree's root to the node.
142
143         .. note::
144
145             If the node does not exist in the tree, the last element
146             on the path will be None but the path will indicate the
147             ancestor path of that node were it to be inserted.
148
149         >>> t = BinarySearchTree()
150         >>> t.insert(50)
151         >>> t.insert(75)
152         >>> t.insert(25)
153         >>> t.insert(12)
154         >>> t.insert(33)
155         >>> t.insert(4)
156         >>> t.insert(88)
157         >>> t
158         50
159         ├──25
160         │  ├──12
161         │  │  └──4
162         │  └──33
163         └──75
164            └──88
165
166         >>> n = t[4]
167         >>> for x in t.parent_path(n):
168         ...     print(x.value)
169         50
170         25
171         12
172         4
173
174         >>> del t[4]
175         >>> for x in t.parent_path(n):
176         ...     if x is not None:
177         ...         print(x.value)
178         ...     else:
179         ...         print(x)
180         50
181         25
182         12
183         None
184
185         """
186         return self._parent_path(self.root, node)
187
188     def __delitem__(self, value: Any) -> bool:
189         """
190         Delete an item from the tree and preserve the BST property.
191
192         Args:
193             value: the value of the node to be deleted.
194
195         Returns:
196             True if the value was found and its associated node was
197             successfully deleted and False otherwise.
198
199         >>> t = BinarySearchTree()
200         >>> t.insert(50)
201         >>> t.insert(75)
202         >>> t.insert(25)
203         >>> t.insert(66)
204         >>> t.insert(22)
205         >>> t.insert(13)
206         >>> t.insert(85)
207         >>> t
208         50
209         ├──25
210         │  └──22
211         │     └──13
212         └──75
213            ├──66
214            └──85
215
216         >>> for value in t.iterate_inorder():
217         ...     print(value)
218         13
219         22
220         25
221         50
222         66
223         75
224         85
225
226         >>> del t[22]  # Note: bool result is discarded
227
228         >>> for value in t.iterate_inorder():
229         ...     print(value)
230         13
231         25
232         50
233         66
234         75
235         85
236
237         >>> t.__delitem__(13)
238         True
239         >>> for value in t.iterate_inorder():
240         ...     print(value)
241         25
242         50
243         66
244         75
245         85
246
247         >>> t.__delitem__(75)
248         True
249         >>> for value in t.iterate_inorder():
250         ...     print(value)
251         25
252         50
253         66
254         85
255         >>> t
256         50
257         ├──25
258         └──85
259            └──66
260
261         >>> t.__delitem__(99)
262         False
263
264         """
265         if self.root is not None:
266             ret = self._delete(value, None, self.root)
267             if ret:
268                 self.count -= 1
269                 if self.count == 0:
270                     self.root = None
271             return ret
272         return False
273
274     def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
275         """This is called just after deleted was deleted from the tree"""
276         pass
277
278     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
279         """Delete helper"""
280         if node.value == value:
281
282             # Deleting a leaf node
283             if node.left is None and node.right is None:
284                 if parent is not None:
285                     if parent.left == node:
286                         parent.left = None
287                     else:
288                         assert parent.right == node
289                         parent.right = None
290                 self._on_delete(parent, node)
291                 return True
292
293             # Node only has a right.
294             elif node.left is None:
295                 assert node.right is not None
296                 if parent is not None:
297                     if parent.left == node:
298                         parent.left = node.right
299                     else:
300                         assert parent.right == node
301                         parent.right = node.right
302                 self._on_delete(parent, node)
303                 return True
304
305             # Node only has a left.
306             elif node.right is None:
307                 assert node.left is not None
308                 if parent is not None:
309                     if parent.left == node:
310                         parent.left = node.left
311                     else:
312                         assert parent.right == node
313                         parent.right = node.left
314                 self._on_delete(parent, node)
315                 return True
316
317             # Node has both a left and right.
318             else:
319                 assert node.left is not None and node.right is not None
320                 descendent = node.right
321                 while descendent.left is not None:
322                     descendent = descendent.left
323                 node.value = descendent.value
324                 return self._delete(node.value, node, node.right)
325         elif value < node.value and node.left is not None:
326             return self._delete(value, node, node.left)
327         elif value > node.value and node.right is not None:
328             return self._delete(value, node, node.right)
329         return False
330
331     def __len__(self):
332         """
333         Returns:
334             The count of items in the tree.
335
336         >>> t = BinarySearchTree()
337         >>> len(t)
338         0
339         >>> t.insert(50)
340         >>> len(t)
341         1
342         >>> t.__delitem__(50)
343         True
344         >>> len(t)
345         0
346         >>> t.insert(75)
347         >>> t.insert(25)
348         >>> t.insert(66)
349         >>> t.insert(22)
350         >>> t.insert(13)
351         >>> t.insert(85)
352         >>> len(t)
353         6
354
355         """
356         return self.count
357
358     def __contains__(self, value: Any) -> bool:
359         """
360         Returns:
361             True if the item is in the tree; False otherwise.
362         """
363         return self.__getitem__(value) is not None
364
365     def _iterate_preorder(self, node: Node):
366         yield node.value
367         if node.left is not None:
368             yield from self._iterate_preorder(node.left)
369         if node.right is not None:
370             yield from self._iterate_preorder(node.right)
371
372     def _iterate_inorder(self, node: Node):
373         if node.left is not None:
374             yield from self._iterate_inorder(node.left)
375         yield node.value
376         if node.right is not None:
377             yield from self._iterate_inorder(node.right)
378
379     def _iterate_postorder(self, node: Node):
380         if node.left is not None:
381             yield from self._iterate_postorder(node.left)
382         if node.right is not None:
383             yield from self._iterate_postorder(node.right)
384         yield node.value
385
386     def iterate_preorder(self):
387         """
388         Returns:
389             A Generator that yields the tree's items in a
390             preorder traversal sequence.
391
392         >>> t = BinarySearchTree()
393         >>> t.insert(50)
394         >>> t.insert(75)
395         >>> t.insert(25)
396         >>> t.insert(66)
397         >>> t.insert(22)
398         >>> t.insert(13)
399
400         >>> for value in t.iterate_preorder():
401         ...     print(value)
402         50
403         25
404         22
405         13
406         75
407         66
408
409         """
410         if self.root is not None:
411             yield from self._iterate_preorder(self.root)
412
413     def iterate_inorder(self):
414         """
415         Returns:
416             A Generator that yield the tree's items in a preorder
417             traversal sequence.
418
419         >>> t = BinarySearchTree()
420         >>> t.insert(50)
421         >>> t.insert(75)
422         >>> t.insert(25)
423         >>> t.insert(66)
424         >>> t.insert(22)
425         >>> t.insert(13)
426         >>> t.insert(24)
427         >>> t
428         50
429         ├──25
430         │  └──22
431         │     ├──13
432         │     └──24
433         └──75
434            └──66
435
436         >>> for value in t.iterate_inorder():
437         ...     print(value)
438         13
439         22
440         24
441         25
442         50
443         66
444         75
445
446         """
447         if self.root is not None:
448             yield from self._iterate_inorder(self.root)
449
450     def iterate_postorder(self):
451         """
452         Returns:
453             A Generator that yield the tree's items in a preorder
454             traversal sequence.
455
456         >>> t = BinarySearchTree()
457         >>> t.insert(50)
458         >>> t.insert(75)
459         >>> t.insert(25)
460         >>> t.insert(66)
461         >>> t.insert(22)
462         >>> t.insert(13)
463
464         >>> for value in t.iterate_postorder():
465         ...     print(value)
466         13
467         22
468         25
469         66
470         75
471         50
472
473         """
474         if self.root is not None:
475             yield from self._iterate_postorder(self.root)
476
477     def _iterate_leaves(self, node: Node):
478         if node.left is not None:
479             yield from self._iterate_leaves(node.left)
480         if node.right is not None:
481             yield from self._iterate_leaves(node.right)
482         if node.left is None and node.right is None:
483             yield node.value
484
485     def iterate_leaves(self):
486         """
487         Returns:
488             A Gemerator that yielde only the leaf nodes in the
489             tree.
490
491         >>> t = BinarySearchTree()
492         >>> t.insert(50)
493         >>> t.insert(75)
494         >>> t.insert(25)
495         >>> t.insert(66)
496         >>> t.insert(22)
497         >>> t.insert(13)
498
499         >>> for value in t.iterate_leaves():
500         ...     print(value)
501         13
502         66
503
504         """
505         if self.root is not None:
506             yield from self._iterate_leaves(self.root)
507
508     def _iterate_by_depth(self, node: Node, depth: int):
509         if depth == 0:
510             yield node.value
511         else:
512             assert depth > 0
513             if node.left is not None:
514                 yield from self._iterate_by_depth(node.left, depth - 1)
515             if node.right is not None:
516                 yield from self._iterate_by_depth(node.right, depth - 1)
517
518     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
519         """
520         Args:
521             depth: the desired depth
522
523         Returns:
524             A Generator that yields nodes at the prescribed depth in
525             the tree.
526
527         >>> t = BinarySearchTree()
528         >>> t.insert(50)
529         >>> t.insert(75)
530         >>> t.insert(25)
531         >>> t.insert(66)
532         >>> t.insert(22)
533         >>> t.insert(13)
534
535         >>> for value in t.iterate_nodes_by_depth(2):
536         ...     print(value)
537         22
538         66
539
540         >>> for value in t.iterate_nodes_by_depth(3):
541         ...     print(value)
542         13
543
544         """
545         if self.root is not None:
546             yield from self._iterate_by_depth(self.root, depth)
547
548     def get_next_node(self, node: Node) -> Node:
549         """
550         Args:
551             node: the node whose next greater successor is desired
552
553         Returns:
554             Given a tree node, returns the next greater node in the tree.
555
556         >>> t = BinarySearchTree()
557         >>> t.insert(50)
558         >>> t.insert(75)
559         >>> t.insert(25)
560         >>> t.insert(66)
561         >>> t.insert(22)
562         >>> t.insert(13)
563         >>> t.insert(23)
564         >>> t
565         50
566         ├──25
567         │  └──22
568         │     ├──13
569         │     └──23
570         └──75
571            └──66
572
573         >>> n = t[23]
574         >>> t.get_next_node(n).value
575         25
576
577         >>> n = t[50]
578         >>> t.get_next_node(n).value
579         66
580
581         """
582         if node.right is not None:
583             x = node.right
584             while x.left is not None:
585                 x = x.left
586             return x
587
588         path = self.parent_path(node)
589         assert path[-1] is not None
590         assert path[-1] == node
591         path = path[:-1]
592         path.reverse()
593         for ancestor in path:
594             assert ancestor is not None
595             if node != ancestor.right:
596                 return ancestor
597             node = ancestor
598         raise Exception()
599
600     def _depth(self, node: Node, sofar: int) -> int:
601         depth_left = sofar + 1
602         depth_right = sofar + 1
603         if node.left is not None:
604             depth_left = self._depth(node.left, sofar + 1)
605         if node.right is not None:
606             depth_right = self._depth(node.right, sofar + 1)
607         return max(depth_left, depth_right)
608
609     def depth(self) -> int:
610         """
611         Returns:
612             The max height (depth) of the tree in plies (edge distance
613             from root).
614
615         >>> t = BinarySearchTree()
616         >>> t.depth()
617         0
618
619         >>> t.insert(50)
620         >>> t.depth()
621         1
622
623         >>> t.insert(65)
624         >>> t.depth()
625         2
626
627         >>> t.insert(33)
628         >>> t.depth()
629         2
630
631         >>> t.insert(2)
632         >>> t.insert(1)
633         >>> t.depth()
634         4
635
636         """
637         if self.root is None:
638             return 0
639         return self._depth(self.root, 0)
640
641     def height(self) -> int:
642         """Returns the height (i.e. max depth) of the tree"""
643         return self.depth()
644
645     def repr_traverse(
646         self,
647         padding: str,
648         pointer: str,
649         node: Optional[Node],
650         has_right_sibling: bool,
651     ) -> str:
652         if node is not None:
653             viz = f"\n{padding}{pointer}{node.value}"
654             if has_right_sibling:
655                 padding += "│  "
656             else:
657                 padding += "   "
658
659             pointer_right = "└──"
660             if node.right is not None:
661                 pointer_left = "├──"
662             else:
663                 pointer_left = "└──"
664
665             viz += self.repr_traverse(
666                 padding, pointer_left, node.left, node.right is not None
667             )
668             viz += self.repr_traverse(padding, pointer_right, node.right, False)
669             return viz
670         return ""
671
672     def __repr__(self):
673         """
674         Returns:
675             An ASCII string representation of the tree.
676
677         >>> t = BinarySearchTree()
678         >>> t.insert(50)
679         >>> t.insert(25)
680         >>> t.insert(75)
681         >>> t.insert(12)
682         >>> t.insert(33)
683         >>> t.insert(88)
684         >>> t.insert(55)
685         >>> t
686         50
687         ├──25
688         │  ├──12
689         │  └──33
690         └──75
691            ├──55
692            └──88
693         """
694         if self.root is None:
695             return ""
696
697         ret = f"{self.root.value}"
698         pointer_right = "└──"
699         if self.root.right is None:
700             pointer_left = "└──"
701         else:
702             pointer_left = "├──"
703
704         ret += self.repr_traverse(
705             "", pointer_left, self.root.left, self.root.left is not None
706         )
707         ret += self.repr_traverse("", pointer_right, self.root.right, False)
708         return ret
709
710
711 if __name__ == "__main__":
712     import doctest
713
714     doctest.testmod()