More work to improve documentation generated by sphinx. Also fixes
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A binary search tree implementation."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         A BST node.  Note that value can be anything as long as it
14         is comparable.  Check out :meth:`functools.total_ordering`
15         (https://docs.python.org/3/library/functools.html#functools.total_ordering)
16
17         Args:
18             value: a reference to the value of the node.
19         """
20         self.left: Optional[Node] = None
21         self.right: Optional[Node] = None
22         self.value = value
23
24
25 class BinarySearchTree(object):
26     def __init__(self):
27         self.root = None
28         self.count = 0
29         self.traverse = None
30
31     def get_root(self) -> Optional[Node]:
32         """
33         Returns:
34             The root of the BST
35         """
36
37         return self.root
38
39     def insert(self, value: Any) -> None:
40         """
41         Insert something into the tree.
42
43         Args:
44             value: the value to be inserted.
45
46         >>> t = BinarySearchTree()
47         >>> t.insert(10)
48         >>> t.insert(20)
49         >>> t.insert(5)
50         >>> len(t)
51         3
52
53         >>> t.get_root().value
54         10
55
56         """
57         if self.root is None:
58             self.root = Node(value)
59             self.count = 1
60         else:
61             self._insert(value, self.root)
62
63     def _insert(self, value: Any, node: Node):
64         """Insertion helper"""
65         if value < node.value:
66             if node.left is not None:
67                 self._insert(value, node.left)
68             else:
69                 node.left = Node(value)
70                 self.count += 1
71         else:
72             if node.right is not None:
73                 self._insert(value, node.right)
74             else:
75                 node.right = Node(value)
76                 self.count += 1
77
78     def __getitem__(self, value: Any) -> Optional[Node]:
79         """
80         Find an item in the tree and return its Node.  Returns
81         None if the item is not in the tree.
82
83         >>> t = BinarySearchTree()
84         >>> t[99]
85
86         >>> t.insert(10)
87         >>> t.insert(20)
88         >>> t.insert(5)
89         >>> t[10].value
90         10
91
92         >>> t[99]
93
94         """
95         if self.root is not None:
96             return self._find(value, self.root)
97         return None
98
99     def _find(self, value: Any, node: Node) -> Optional[Node]:
100         """Find helper"""
101         if value == node.value:
102             return node
103         elif value < node.value and node.left is not None:
104             return self._find(value, node.left)
105         elif value > node.value and node.right is not None:
106             return self._find(value, node.right)
107         return None
108
109     def _parent_path(
110         self, current: Optional[Node], target: Node
111     ) -> List[Optional[Node]]:
112         """Internal helper"""
113         if current is None:
114             return [None]
115         ret: List[Optional[Node]] = [current]
116         if target.value == current.value:
117             return ret
118         elif target.value < current.value:
119             ret.extend(self._parent_path(current.left, target))
120             return ret
121         else:
122             assert target.value > current.value
123             ret.extend(self._parent_path(current.right, target))
124             return ret
125
126     def parent_path(self, node: Node) -> List[Optional[Node]]:
127         """Get a node's parent path.
128
129         Args:
130             node: the node to check
131
132         Returns:
133             a list of nodes representing the path from
134             the tree's root to the node.
135
136         .. note::
137
138             If the node does not exist in the tree, the last element
139             on the path will be None but the path will indicate the
140             ancestor path of that node were it to be inserted.
141
142         >>> t = BinarySearchTree()
143         >>> t.insert(50)
144         >>> t.insert(75)
145         >>> t.insert(25)
146         >>> t.insert(12)
147         >>> t.insert(33)
148         >>> t.insert(4)
149         >>> t.insert(88)
150         >>> t
151         50
152         ├──25
153         │  ├──12
154         │  │  └──4
155         │  └──33
156         └──75
157            └──88
158
159         >>> n = t[4]
160         >>> for x in t.parent_path(n):
161         ...     print(x.value)
162         50
163         25
164         12
165         4
166
167         >>> del t[4]
168         >>> for x in t.parent_path(n):
169         ...     if x is not None:
170         ...         print(x.value)
171         ...     else:
172         ...         print(x)
173         50
174         25
175         12
176         None
177
178         """
179         return self._parent_path(self.root, node)
180
181     def __delitem__(self, value: Any) -> bool:
182         """
183         Delete an item from the tree and preserve the BST property.
184
185         Args:
186             value: the value of the node to be deleted.
187
188         Returns:
189             True if the value was found and its associated node was
190             successfully deleted and False otherwise.
191
192         >>> t = BinarySearchTree()
193         >>> t.insert(50)
194         >>> t.insert(75)
195         >>> t.insert(25)
196         >>> t.insert(66)
197         >>> t.insert(22)
198         >>> t.insert(13)
199         >>> t.insert(85)
200         >>> t
201         50
202         ├──25
203         │  └──22
204         │     └──13
205         └──75
206            ├──66
207            └──85
208
209         >>> for value in t.iterate_inorder():
210         ...     print(value)
211         13
212         22
213         25
214         50
215         66
216         75
217         85
218
219         >>> del t[22]  # Note: bool result is discarded
220
221         >>> for value in t.iterate_inorder():
222         ...     print(value)
223         13
224         25
225         50
226         66
227         75
228         85
229
230         >>> t.__delitem__(13)
231         True
232         >>> for value in t.iterate_inorder():
233         ...     print(value)
234         25
235         50
236         66
237         75
238         85
239
240         >>> t.__delitem__(75)
241         True
242         >>> for value in t.iterate_inorder():
243         ...     print(value)
244         25
245         50
246         66
247         85
248         >>> t
249         50
250         ├──25
251         └──85
252            └──66
253
254         >>> t.__delitem__(99)
255         False
256
257         """
258         if self.root is not None:
259             ret = self._delete(value, None, self.root)
260             if ret:
261                 self.count -= 1
262                 if self.count == 0:
263                     self.root = None
264             return ret
265         return False
266
267     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
268         """Delete helper"""
269         if node.value == value:
270             # Deleting a leaf node
271             if node.left is None and node.right is None:
272                 if parent is not None:
273                     if parent.left == node:
274                         parent.left = None
275                     else:
276                         assert parent.right == node
277                         parent.right = None
278                 return True
279
280             # Node only has a right.
281             elif node.left is None:
282                 assert node.right is not None
283                 if parent is not None:
284                     if parent.left == node:
285                         parent.left = node.right
286                     else:
287                         assert parent.right == node
288                         parent.right = node.right
289                 return True
290
291             # Node only has a left.
292             elif node.right is None:
293                 assert node.left is not None
294                 if parent is not None:
295                     if parent.left == node:
296                         parent.left = node.left
297                     else:
298                         assert parent.right == node
299                         parent.right = node.left
300                 return True
301
302             # Node has both a left and right.
303             else:
304                 assert node.left is not None and node.right is not None
305                 descendent = node.right
306                 while descendent.left is not None:
307                     descendent = descendent.left
308                 node.value = descendent.value
309                 return self._delete(node.value, node, node.right)
310         elif value < node.value and node.left is not None:
311             return self._delete(value, node, node.left)
312         elif value > node.value and node.right is not None:
313             return self._delete(value, node, node.right)
314         return False
315
316     def __len__(self):
317         """
318         Returns:
319             The count of items in the tree.
320
321         >>> t = BinarySearchTree()
322         >>> len(t)
323         0
324         >>> t.insert(50)
325         >>> len(t)
326         1
327         >>> t.__delitem__(50)
328         True
329         >>> len(t)
330         0
331         >>> t.insert(75)
332         >>> t.insert(25)
333         >>> t.insert(66)
334         >>> t.insert(22)
335         >>> t.insert(13)
336         >>> t.insert(85)
337         >>> len(t)
338         6
339
340         """
341         return self.count
342
343     def __contains__(self, value: Any) -> bool:
344         """
345         Returns:
346             True if the item is in the tree; False otherwise.
347         """
348         return self.__getitem__(value) is not None
349
350     def _iterate_preorder(self, node: Node):
351         yield node.value
352         if node.left is not None:
353             yield from self._iterate_preorder(node.left)
354         if node.right is not None:
355             yield from self._iterate_preorder(node.right)
356
357     def _iterate_inorder(self, node: Node):
358         if node.left is not None:
359             yield from self._iterate_inorder(node.left)
360         yield node.value
361         if node.right is not None:
362             yield from self._iterate_inorder(node.right)
363
364     def _iterate_postorder(self, node: Node):
365         if node.left is not None:
366             yield from self._iterate_postorder(node.left)
367         if node.right is not None:
368             yield from self._iterate_postorder(node.right)
369         yield node.value
370
371     def iterate_preorder(self):
372         """
373         Returns:
374             A Generator that yields the tree's items in a
375             preorder traversal sequence.
376
377         >>> t = BinarySearchTree()
378         >>> t.insert(50)
379         >>> t.insert(75)
380         >>> t.insert(25)
381         >>> t.insert(66)
382         >>> t.insert(22)
383         >>> t.insert(13)
384
385         >>> for value in t.iterate_preorder():
386         ...     print(value)
387         50
388         25
389         22
390         13
391         75
392         66
393
394         """
395         if self.root is not None:
396             yield from self._iterate_preorder(self.root)
397
398     def iterate_inorder(self):
399         """
400         Returns:
401             A Generator that yield the tree's items in a preorder
402             traversal sequence.
403
404         >>> t = BinarySearchTree()
405         >>> t.insert(50)
406         >>> t.insert(75)
407         >>> t.insert(25)
408         >>> t.insert(66)
409         >>> t.insert(22)
410         >>> t.insert(13)
411         >>> t.insert(24)
412         >>> t
413         50
414         ├──25
415         │  └──22
416         │     ├──13
417         │     └──24
418         └──75
419            └──66
420
421         >>> for value in t.iterate_inorder():
422         ...     print(value)
423         13
424         22
425         24
426         25
427         50
428         66
429         75
430
431         """
432         if self.root is not None:
433             yield from self._iterate_inorder(self.root)
434
435     def iterate_postorder(self):
436         """
437         Returns:
438             A Generator that yield the tree's items in a preorder
439             traversal sequence.
440
441         >>> t = BinarySearchTree()
442         >>> t.insert(50)
443         >>> t.insert(75)
444         >>> t.insert(25)
445         >>> t.insert(66)
446         >>> t.insert(22)
447         >>> t.insert(13)
448
449         >>> for value in t.iterate_postorder():
450         ...     print(value)
451         13
452         22
453         25
454         66
455         75
456         50
457
458         """
459         if self.root is not None:
460             yield from self._iterate_postorder(self.root)
461
462     def _iterate_leaves(self, node: Node):
463         if node.left is not None:
464             yield from self._iterate_leaves(node.left)
465         if node.right is not None:
466             yield from self._iterate_leaves(node.right)
467         if node.left is None and node.right is None:
468             yield node.value
469
470     def iterate_leaves(self):
471         """
472         Returns:
473             A Gemerator that yielde only the leaf nodes in the
474             tree.
475
476         >>> t = BinarySearchTree()
477         >>> t.insert(50)
478         >>> t.insert(75)
479         >>> t.insert(25)
480         >>> t.insert(66)
481         >>> t.insert(22)
482         >>> t.insert(13)
483
484         >>> for value in t.iterate_leaves():
485         ...     print(value)
486         13
487         66
488
489         """
490         if self.root is not None:
491             yield from self._iterate_leaves(self.root)
492
493     def _iterate_by_depth(self, node: Node, depth: int):
494         if depth == 0:
495             yield node.value
496         else:
497             assert depth > 0
498             if node.left is not None:
499                 yield from self._iterate_by_depth(node.left, depth - 1)
500             if node.right is not None:
501                 yield from self._iterate_by_depth(node.right, depth - 1)
502
503     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
504         """
505         Args:
506             depth: the desired depth
507
508         Returns:
509             A Generator that yields nodes at the prescribed depth in
510             the tree.
511
512         >>> t = BinarySearchTree()
513         >>> t.insert(50)
514         >>> t.insert(75)
515         >>> t.insert(25)
516         >>> t.insert(66)
517         >>> t.insert(22)
518         >>> t.insert(13)
519
520         >>> for value in t.iterate_nodes_by_depth(2):
521         ...     print(value)
522         22
523         66
524
525         >>> for value in t.iterate_nodes_by_depth(3):
526         ...     print(value)
527         13
528
529         """
530         if self.root is not None:
531             yield from self._iterate_by_depth(self.root, depth)
532
533     def get_next_node(self, node: Node) -> Node:
534         """
535         Args:
536             node: the node whose next greater successor is desired
537
538         Returns:
539             Given a tree node, returns the next greater node in the tree.
540
541         >>> t = BinarySearchTree()
542         >>> t.insert(50)
543         >>> t.insert(75)
544         >>> t.insert(25)
545         >>> t.insert(66)
546         >>> t.insert(22)
547         >>> t.insert(13)
548         >>> t.insert(23)
549         >>> t
550         50
551         ├──25
552         │  └──22
553         │     ├──13
554         │     └──23
555         └──75
556            └──66
557
558         >>> n = t[23]
559         >>> t.get_next_node(n).value
560         25
561
562         >>> n = t[50]
563         >>> t.get_next_node(n).value
564         66
565
566         """
567         if node.right is not None:
568             x = node.right
569             while x.left is not None:
570                 x = x.left
571             return x
572
573         path = self.parent_path(node)
574         assert path[-1] is not None
575         assert path[-1] == node
576         path = path[:-1]
577         path.reverse()
578         for ancestor in path:
579             assert ancestor is not None
580             if node != ancestor.right:
581                 return ancestor
582             node = ancestor
583         raise Exception()
584
585     def _depth(self, node: Node, sofar: int) -> int:
586         depth_left = sofar + 1
587         depth_right = sofar + 1
588         if node.left is not None:
589             depth_left = self._depth(node.left, sofar + 1)
590         if node.right is not None:
591             depth_right = self._depth(node.right, sofar + 1)
592         return max(depth_left, depth_right)
593
594     def depth(self) -> int:
595         """
596         Returns:
597             The max height (depth) of the tree in plies (edge distance
598             from root).
599
600         >>> t = BinarySearchTree()
601         >>> t.depth()
602         0
603
604         >>> t.insert(50)
605         >>> t.depth()
606         1
607
608         >>> t.insert(65)
609         >>> t.depth()
610         2
611
612         >>> t.insert(33)
613         >>> t.depth()
614         2
615
616         >>> t.insert(2)
617         >>> t.insert(1)
618         >>> t.depth()
619         4
620
621         """
622         if self.root is None:
623             return 0
624         return self._depth(self.root, 0)
625
626     def height(self) -> int:
627         """Returns the height (i.e. max depth) of the tree"""
628         return self.depth()
629
630     def repr_traverse(
631         self,
632         padding: str,
633         pointer: str,
634         node: Optional[Node],
635         has_right_sibling: bool,
636     ) -> str:
637         if node is not None:
638             viz = f'\n{padding}{pointer}{node.value}'
639             if has_right_sibling:
640                 padding += "│  "
641             else:
642                 padding += '   '
643
644             pointer_right = "└──"
645             if node.right is not None:
646                 pointer_left = "├──"
647             else:
648                 pointer_left = "└──"
649
650             viz += self.repr_traverse(
651                 padding, pointer_left, node.left, node.right is not None
652             )
653             viz += self.repr_traverse(padding, pointer_right, node.right, False)
654             return viz
655         return ""
656
657     def __repr__(self):
658         """
659         Returns:
660             An ASCII string representation of the tree.
661
662         >>> t = BinarySearchTree()
663         >>> t.insert(50)
664         >>> t.insert(25)
665         >>> t.insert(75)
666         >>> t.insert(12)
667         >>> t.insert(33)
668         >>> t.insert(88)
669         >>> t.insert(55)
670         >>> t
671         50
672         ├──25
673         │  ├──12
674         │  └──33
675         └──75
676            ├──55
677            └──88
678         """
679         if self.root is None:
680             return ""
681
682         ret = f'{self.root.value}'
683         pointer_right = "└──"
684         if self.root.right is None:
685             pointer_left = "└──"
686         else:
687             pointer_left = "├──"
688
689         ret += self.repr_traverse(
690             '', pointer_left, self.root.left, self.root.left is not None
691         )
692         ret += self.repr_traverse('', pointer_right, self.root.right, False)
693         return ret
694
695
696 if __name__ == '__main__':
697     import doctest
698
699     doctest.testmod()