Migration from old pyutilz package name (which, in turn, came from
[pyutils.git] / src / pyutils / collectionz / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """A binary search tree."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         Note: value can be anything as long as it is comparable.
14         Check out @functools.total_ordering.
15         """
16         self.left: Optional[Node] = None
17         self.right: Optional[Node] = None
18         self.value = value
19
20
21 class BinarySearchTree(object):
22     def __init__(self):
23         self.root = None
24         self.count = 0
25         self.traverse = None
26
27     def get_root(self) -> Optional[Node]:
28         return self.root
29
30     def insert(self, value: Any):
31         """
32         Insert something into the tree.
33
34         >>> t = BinarySearchTree()
35         >>> t.insert(10)
36         >>> t.insert(20)
37         >>> t.insert(5)
38         >>> len(t)
39         3
40
41         >>> t.get_root().value
42         10
43
44         """
45         if self.root is None:
46             self.root = Node(value)
47             self.count = 1
48         else:
49             self._insert(value, self.root)
50
51     def _insert(self, value: Any, node: Node):
52         """Insertion helper"""
53         if value < node.value:
54             if node.left is not None:
55                 self._insert(value, node.left)
56             else:
57                 node.left = Node(value)
58                 self.count += 1
59         else:
60             if node.right is not None:
61                 self._insert(value, node.right)
62             else:
63                 node.right = Node(value)
64                 self.count += 1
65
66     def __getitem__(self, value: Any) -> Optional[Node]:
67         """
68         Find an item in the tree and return its Node.  Returns
69         None if the item is not in the tree.
70
71         >>> t = BinarySearchTree()
72         >>> t[99]
73
74         >>> t.insert(10)
75         >>> t.insert(20)
76         >>> t.insert(5)
77         >>> t[10].value
78         10
79
80         >>> t[99]
81
82         """
83         if self.root is not None:
84             return self._find(value, self.root)
85         return None
86
87     def _find(self, value: Any, node: Node) -> Optional[Node]:
88         """Find helper"""
89         if value == node.value:
90             return node
91         elif value < node.value and node.left is not None:
92             return self._find(value, node.left)
93         elif value > node.value and node.right is not None:
94             return self._find(value, node.right)
95         return None
96
97     def _parent_path(
98         self, current: Optional[Node], target: Node
99     ) -> List[Optional[Node]]:
100         if current is None:
101             return [None]
102         ret: List[Optional[Node]] = [current]
103         if target.value == current.value:
104             return ret
105         elif target.value < current.value:
106             ret.extend(self._parent_path(current.left, target))
107             return ret
108         else:
109             assert target.value > current.value
110             ret.extend(self._parent_path(current.right, target))
111             return ret
112
113     def parent_path(self, node: Node) -> List[Optional[Node]]:
114         """Return a list of nodes representing the path from
115         the tree's root to the node argument.  If the node does
116         not exist in the tree for some reason, the last element
117         on the path will be None but the path will indicate the
118         ancestor path of that node were it inserted.
119
120         >>> t = BinarySearchTree()
121         >>> t.insert(50)
122         >>> t.insert(75)
123         >>> t.insert(25)
124         >>> t.insert(12)
125         >>> t.insert(33)
126         >>> t.insert(4)
127         >>> t.insert(88)
128         >>> t
129         50
130         ├──25
131         │  ├──12
132         │  │  └──4
133         │  └──33
134         └──75
135            └──88
136
137         >>> n = t[4]
138         >>> for x in t.parent_path(n):
139         ...     print(x.value)
140         50
141         25
142         12
143         4
144
145         >>> del t[4]
146         >>> for x in t.parent_path(n):
147         ...     if x is not None:
148         ...         print(x.value)
149         ...     else:
150         ...         print(x)
151         50
152         25
153         12
154         None
155
156         """
157         return self._parent_path(self.root, node)
158
159     def __delitem__(self, value: Any) -> bool:
160         """
161         Delete an item from the tree and preserve the BST property.
162
163         >>> t = BinarySearchTree()
164         >>> t.insert(50)
165         >>> t.insert(75)
166         >>> t.insert(25)
167         >>> t.insert(66)
168         >>> t.insert(22)
169         >>> t.insert(13)
170         >>> t.insert(85)
171         >>> t
172         50
173         ├──25
174         │  └──22
175         │     └──13
176         └──75
177            ├──66
178            └──85
179
180         >>> for value in t.iterate_inorder():
181         ...     print(value)
182         13
183         22
184         25
185         50
186         66
187         75
188         85
189
190         >>> del t[22]  # Note: bool result is discarded
191
192         >>> for value in t.iterate_inorder():
193         ...     print(value)
194         13
195         25
196         50
197         66
198         75
199         85
200
201         >>> t.__delitem__(13)
202         True
203         >>> for value in t.iterate_inorder():
204         ...     print(value)
205         25
206         50
207         66
208         75
209         85
210
211         >>> t.__delitem__(75)
212         True
213         >>> for value in t.iterate_inorder():
214         ...     print(value)
215         25
216         50
217         66
218         85
219         >>> t
220         50
221         ├──25
222         └──85
223            └──66
224
225         >>> t.__delitem__(99)
226         False
227
228         """
229         if self.root is not None:
230             ret = self._delete(value, None, self.root)
231             if ret:
232                 self.count -= 1
233                 if self.count == 0:
234                     self.root = None
235             return ret
236         return False
237
238     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
239         """Delete helper"""
240         if node.value == value:
241             # Deleting a leaf node
242             if node.left is None and node.right is None:
243                 if parent is not None:
244                     if parent.left == node:
245                         parent.left = None
246                     else:
247                         assert parent.right == node
248                         parent.right = None
249                 return True
250
251             # Node only has a right.
252             elif node.left is None:
253                 assert node.right is not None
254                 if parent is not None:
255                     if parent.left == node:
256                         parent.left = node.right
257                     else:
258                         assert parent.right == node
259                         parent.right = node.right
260                 return True
261
262             # Node only has a left.
263             elif node.right is None:
264                 assert node.left is not None
265                 if parent is not None:
266                     if parent.left == node:
267                         parent.left = node.left
268                     else:
269                         assert parent.right == node
270                         parent.right = node.left
271                 return True
272
273             # Node has both a left and right.
274             else:
275                 assert node.left is not None and node.right is not None
276                 descendent = node.right
277                 while descendent.left is not None:
278                     descendent = descendent.left
279                 node.value = descendent.value
280                 return self._delete(node.value, node, node.right)
281         elif value < node.value and node.left is not None:
282             return self._delete(value, node, node.left)
283         elif value > node.value and node.right is not None:
284             return self._delete(value, node, node.right)
285         return False
286
287     def __len__(self):
288         """
289         Returns the count of items in the tree.
290
291         >>> t = BinarySearchTree()
292         >>> len(t)
293         0
294         >>> t.insert(50)
295         >>> len(t)
296         1
297         >>> t.__delitem__(50)
298         True
299         >>> len(t)
300         0
301         >>> t.insert(75)
302         >>> t.insert(25)
303         >>> t.insert(66)
304         >>> t.insert(22)
305         >>> t.insert(13)
306         >>> t.insert(85)
307         >>> len(t)
308         6
309
310         """
311         return self.count
312
313     def __contains__(self, value: Any) -> bool:
314         """
315         Returns True if the item is in the tree; False otherwise.
316
317         """
318         return self.__getitem__(value) is not None
319
320     def _iterate_preorder(self, node: Node):
321         yield node.value
322         if node.left is not None:
323             yield from self._iterate_preorder(node.left)
324         if node.right is not None:
325             yield from self._iterate_preorder(node.right)
326
327     def _iterate_inorder(self, node: Node):
328         if node.left is not None:
329             yield from self._iterate_inorder(node.left)
330         yield node.value
331         if node.right is not None:
332             yield from self._iterate_inorder(node.right)
333
334     def _iterate_postorder(self, node: Node):
335         if node.left is not None:
336             yield from self._iterate_postorder(node.left)
337         if node.right is not None:
338             yield from self._iterate_postorder(node.right)
339         yield node.value
340
341     def iterate_preorder(self):
342         """
343         Yield the tree's items in a preorder traversal sequence.
344
345         >>> t = BinarySearchTree()
346         >>> t.insert(50)
347         >>> t.insert(75)
348         >>> t.insert(25)
349         >>> t.insert(66)
350         >>> t.insert(22)
351         >>> t.insert(13)
352
353         >>> for value in t.iterate_preorder():
354         ...     print(value)
355         50
356         25
357         22
358         13
359         75
360         66
361
362         """
363         if self.root is not None:
364             yield from self._iterate_preorder(self.root)
365
366     def iterate_inorder(self):
367         """
368         Yield the tree's items in a preorder traversal sequence.
369
370         >>> t = BinarySearchTree()
371         >>> t.insert(50)
372         >>> t.insert(75)
373         >>> t.insert(25)
374         >>> t.insert(66)
375         >>> t.insert(22)
376         >>> t.insert(13)
377         >>> t.insert(24)
378         >>> t
379         50
380         ├──25
381         │  └──22
382         │     ├──13
383         │     └──24
384         └──75
385            └──66
386
387         >>> for value in t.iterate_inorder():
388         ...     print(value)
389         13
390         22
391         24
392         25
393         50
394         66
395         75
396
397         """
398         if self.root is not None:
399             yield from self._iterate_inorder(self.root)
400
401     def iterate_postorder(self):
402         """
403         Yield the tree's items in a preorder traversal sequence.
404
405         >>> t = BinarySearchTree()
406         >>> t.insert(50)
407         >>> t.insert(75)
408         >>> t.insert(25)
409         >>> t.insert(66)
410         >>> t.insert(22)
411         >>> t.insert(13)
412
413         >>> for value in t.iterate_postorder():
414         ...     print(value)
415         13
416         22
417         25
418         66
419         75
420         50
421
422         """
423         if self.root is not None:
424             yield from self._iterate_postorder(self.root)
425
426     def _iterate_leaves(self, node: Node):
427         if node.left is not None:
428             yield from self._iterate_leaves(node.left)
429         if node.right is not None:
430             yield from self._iterate_leaves(node.right)
431         if node.left is None and node.right is None:
432             yield node.value
433
434     def iterate_leaves(self):
435         """
436         Iterate only the leaf nodes in the tree.
437
438         >>> t = BinarySearchTree()
439         >>> t.insert(50)
440         >>> t.insert(75)
441         >>> t.insert(25)
442         >>> t.insert(66)
443         >>> t.insert(22)
444         >>> t.insert(13)
445
446         >>> for value in t.iterate_leaves():
447         ...     print(value)
448         13
449         66
450
451         """
452         if self.root is not None:
453             yield from self._iterate_leaves(self.root)
454
455     def _iterate_by_depth(self, node: Node, depth: int):
456         if depth == 0:
457             yield node.value
458         else:
459             assert depth > 0
460             if node.left is not None:
461                 yield from self._iterate_by_depth(node.left, depth - 1)
462             if node.right is not None:
463                 yield from self._iterate_by_depth(node.right, depth - 1)
464
465     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
466         """
467         Iterate only the leaf nodes in the tree.
468
469         >>> t = BinarySearchTree()
470         >>> t.insert(50)
471         >>> t.insert(75)
472         >>> t.insert(25)
473         >>> t.insert(66)
474         >>> t.insert(22)
475         >>> t.insert(13)
476
477         >>> for value in t.iterate_nodes_by_depth(2):
478         ...     print(value)
479         22
480         66
481
482         >>> for value in t.iterate_nodes_by_depth(3):
483         ...     print(value)
484         13
485
486         """
487         if self.root is not None:
488             yield from self._iterate_by_depth(self.root, depth)
489
490     def get_next_node(self, node: Node) -> Node:
491         """
492         Given a tree node, get the next greater node in the tree.
493
494         >>> t = BinarySearchTree()
495         >>> t.insert(50)
496         >>> t.insert(75)
497         >>> t.insert(25)
498         >>> t.insert(66)
499         >>> t.insert(22)
500         >>> t.insert(13)
501         >>> t.insert(23)
502         >>> t
503         50
504         ├──25
505         │  └──22
506         │     ├──13
507         │     └──23
508         └──75
509            └──66
510
511         >>> n = t[23]
512         >>> t.get_next_node(n).value
513         25
514
515         >>> n = t[50]
516         >>> t.get_next_node(n).value
517         66
518
519         """
520         if node.right is not None:
521             x = node.right
522             while x.left is not None:
523                 x = x.left
524             return x
525
526         path = self.parent_path(node)
527         assert path[-1] is not None
528         assert path[-1] == node
529         path = path[:-1]
530         path.reverse()
531         for ancestor in path:
532             assert ancestor is not None
533             if node != ancestor.right:
534                 return ancestor
535             node = ancestor
536         raise Exception()
537
538     def _depth(self, node: Node, sofar: int) -> int:
539         depth_left = sofar + 1
540         depth_right = sofar + 1
541         if node.left is not None:
542             depth_left = self._depth(node.left, sofar + 1)
543         if node.right is not None:
544             depth_right = self._depth(node.right, sofar + 1)
545         return max(depth_left, depth_right)
546
547     def depth(self):
548         """
549         Returns the max height (depth) of the tree in plies (edge distance
550         from root).
551
552         >>> t = BinarySearchTree()
553         >>> t.depth()
554         0
555
556         >>> t.insert(50)
557         >>> t.depth()
558         1
559
560         >>> t.insert(65)
561         >>> t.depth()
562         2
563
564         >>> t.insert(33)
565         >>> t.depth()
566         2
567
568         >>> t.insert(2)
569         >>> t.insert(1)
570         >>> t.depth()
571         4
572
573         """
574         if self.root is None:
575             return 0
576         return self._depth(self.root, 0)
577
578     def height(self):
579         return self.depth()
580
581     def repr_traverse(
582         self,
583         padding: str,
584         pointer: str,
585         node: Optional[Node],
586         has_right_sibling: bool,
587     ) -> str:
588         if node is not None:
589             viz = f'\n{padding}{pointer}{node.value}'
590             if has_right_sibling:
591                 padding += "│  "
592             else:
593                 padding += '   '
594
595             pointer_right = "└──"
596             if node.right is not None:
597                 pointer_left = "├──"
598             else:
599                 pointer_left = "└──"
600
601             viz += self.repr_traverse(
602                 padding, pointer_left, node.left, node.right is not None
603             )
604             viz += self.repr_traverse(padding, pointer_right, node.right, False)
605             return viz
606         return ""
607
608     def __repr__(self):
609         """
610         Draw the tree in ASCII.
611
612         >>> t = BinarySearchTree()
613         >>> t.insert(50)
614         >>> t.insert(25)
615         >>> t.insert(75)
616         >>> t.insert(12)
617         >>> t.insert(33)
618         >>> t.insert(88)
619         >>> t.insert(55)
620         >>> t
621         50
622         ├──25
623         │  ├──12
624         │  └──33
625         └──75
626            ├──55
627            └──88
628         """
629         if self.root is None:
630             return ""
631
632         ret = f'{self.root.value}'
633         pointer_right = "└──"
634         if self.root.right is None:
635             pointer_left = "└──"
636         else:
637             pointer_left = "├──"
638
639         ret += self.repr_traverse(
640             '', pointer_left, self.root.left, self.root.left is not None
641         )
642         ret += self.repr_traverse('', pointer_right, self.root.right, False)
643         return ret