Since this thing is on the innerwebs I suppose it should have a
[python_utils.git] / collect / bst.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """Binary search tree."""
6
7 from typing import Any, Generator, List, Optional
8
9
10 class Node(object):
11     def __init__(self, value: Any) -> None:
12         """
13         Note: value can be anything as long as it is comparable.
14         Check out @functools.total_ordering.
15         """
16         self.left: Optional[Node] = None
17         self.right: Optional[Node] = None
18         self.value = value
19
20
21 class BinarySearchTree(object):
22     def __init__(self):
23         self.root = None
24         self.count = 0
25         self.traverse = None
26
27     def get_root(self) -> Optional[Node]:
28         return self.root
29
30     def insert(self, value: Any):
31         """
32         Insert something into the tree.
33
34         >>> t = BinarySearchTree()
35         >>> t.insert(10)
36         >>> t.insert(20)
37         >>> t.insert(5)
38         >>> len(t)
39         3
40
41         >>> t.get_root().value
42         10
43
44         """
45         if self.root is None:
46             self.root = Node(value)
47             self.count = 1
48         else:
49             self._insert(value, self.root)
50
51     def _insert(self, value: Any, node: Node):
52         """Insertion helper"""
53         if value < node.value:
54             if node.left is not None:
55                 self._insert(value, node.left)
56             else:
57                 node.left = Node(value)
58                 self.count += 1
59         else:
60             if node.right is not None:
61                 self._insert(value, node.right)
62             else:
63                 node.right = Node(value)
64                 self.count += 1
65
66     def __getitem__(self, value: Any) -> Optional[Node]:
67         """
68         Find an item in the tree and return its Node.  Returns
69         None if the item is not in the tree.
70
71         >>> t = BinarySearchTree()
72         >>> t[99]
73
74         >>> t.insert(10)
75         >>> t.insert(20)
76         >>> t.insert(5)
77         >>> t[10].value
78         10
79
80         >>> t[99]
81
82         """
83         if self.root is not None:
84             return self._find(value, self.root)
85         return None
86
87     def _find(self, value: Any, node: Node) -> Optional[Node]:
88         """Find helper"""
89         if value == node.value:
90             return node
91         elif value < node.value and node.left is not None:
92             return self._find(value, node.left)
93         elif value > node.value and node.right is not None:
94             return self._find(value, node.right)
95         return None
96
97     def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
98         if current is None:
99             return [None]
100         ret: List[Optional[Node]] = [current]
101         if target.value == current.value:
102             return ret
103         elif target.value < current.value:
104             ret.extend(self._parent_path(current.left, target))
105             return ret
106         else:
107             assert target.value > current.value
108             ret.extend(self._parent_path(current.right, target))
109             return ret
110
111     def parent_path(self, node: Node) -> List[Optional[Node]]:
112         """Return a list of nodes representing the path from
113         the tree's root to the node argument.  If the node does
114         not exist in the tree for some reason, the last element
115         on the path will be None but the path will indicate the
116         ancestor path of that node were it inserted.
117
118         >>> t = BinarySearchTree()
119         >>> t.insert(50)
120         >>> t.insert(75)
121         >>> t.insert(25)
122         >>> t.insert(12)
123         >>> t.insert(33)
124         >>> t.insert(4)
125         >>> t.insert(88)
126         >>> t
127         50
128         ├──25
129         │  ├──12
130         │  │  └──4
131         │  └──33
132         └──75
133            └──88
134
135         >>> n = t[4]
136         >>> for x in t.parent_path(n):
137         ...     print(x.value)
138         50
139         25
140         12
141         4
142
143         >>> del t[4]
144         >>> for x in t.parent_path(n):
145         ...     if x is not None:
146         ...         print(x.value)
147         ...     else:
148         ...         print(x)
149         50
150         25
151         12
152         None
153
154         """
155         return self._parent_path(self.root, node)
156
157     def __delitem__(self, value: Any) -> bool:
158         """
159         Delete an item from the tree and preserve the BST property.
160
161         >>> t = BinarySearchTree()
162         >>> t.insert(50)
163         >>> t.insert(75)
164         >>> t.insert(25)
165         >>> t.insert(66)
166         >>> t.insert(22)
167         >>> t.insert(13)
168         >>> t.insert(85)
169         >>> t
170         50
171         ├──25
172         │  └──22
173         │     └──13
174         └──75
175            ├──66
176            └──85
177
178         >>> for value in t.iterate_inorder():
179         ...     print(value)
180         13
181         22
182         25
183         50
184         66
185         75
186         85
187
188         >>> del t[22]  # Note: bool result is discarded
189
190         >>> for value in t.iterate_inorder():
191         ...     print(value)
192         13
193         25
194         50
195         66
196         75
197         85
198
199         >>> t.__delitem__(13)
200         True
201         >>> for value in t.iterate_inorder():
202         ...     print(value)
203         25
204         50
205         66
206         75
207         85
208
209         >>> t.__delitem__(75)
210         True
211         >>> for value in t.iterate_inorder():
212         ...     print(value)
213         25
214         50
215         66
216         85
217         >>> t
218         50
219         ├──25
220         └──85
221            └──66
222
223         >>> t.__delitem__(99)
224         False
225
226         """
227         if self.root is not None:
228             ret = self._delete(value, None, self.root)
229             if ret:
230                 self.count -= 1
231                 if self.count == 0:
232                     self.root = None
233             return ret
234         return False
235
236     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
237         """Delete helper"""
238         if node.value == value:
239             # Deleting a leaf node
240             if node.left is None and node.right is None:
241                 if parent is not None:
242                     if parent.left == node:
243                         parent.left = None
244                     else:
245                         assert parent.right == node
246                         parent.right = None
247                 return True
248
249             # Node only has a right.
250             elif node.left is None:
251                 assert node.right is not None
252                 if parent is not None:
253                     if parent.left == node:
254                         parent.left = node.right
255                     else:
256                         assert parent.right == node
257                         parent.right = node.right
258                 return True
259
260             # Node only has a left.
261             elif node.right is None:
262                 assert node.left is not None
263                 if parent is not None:
264                     if parent.left == node:
265                         parent.left = node.left
266                     else:
267                         assert parent.right == node
268                         parent.right = node.left
269                 return True
270
271             # Node has both a left and right.
272             else:
273                 assert node.left is not None and node.right is not None
274                 descendent = node.right
275                 while descendent.left is not None:
276                     descendent = descendent.left
277                 node.value = descendent.value
278                 return self._delete(node.value, node, node.right)
279         elif value < node.value and node.left is not None:
280             return self._delete(value, node, node.left)
281         elif value > node.value and node.right is not None:
282             return self._delete(value, node, node.right)
283         return False
284
285     def __len__(self):
286         """
287         Returns the count of items in the tree.
288
289         >>> t = BinarySearchTree()
290         >>> len(t)
291         0
292         >>> t.insert(50)
293         >>> len(t)
294         1
295         >>> t.__delitem__(50)
296         True
297         >>> len(t)
298         0
299         >>> t.insert(75)
300         >>> t.insert(25)
301         >>> t.insert(66)
302         >>> t.insert(22)
303         >>> t.insert(13)
304         >>> t.insert(85)
305         >>> len(t)
306         6
307
308         """
309         return self.count
310
311     def __contains__(self, value: Any) -> bool:
312         """
313         Returns True if the item is in the tree; False otherwise.
314
315         """
316         return self.__getitem__(value) is not None
317
318     def _iterate_preorder(self, node: Node):
319         yield node.value
320         if node.left is not None:
321             yield from self._iterate_preorder(node.left)
322         if node.right is not None:
323             yield from self._iterate_preorder(node.right)
324
325     def _iterate_inorder(self, node: Node):
326         if node.left is not None:
327             yield from self._iterate_inorder(node.left)
328         yield node.value
329         if node.right is not None:
330             yield from self._iterate_inorder(node.right)
331
332     def _iterate_postorder(self, node: Node):
333         if node.left is not None:
334             yield from self._iterate_postorder(node.left)
335         if node.right is not None:
336             yield from self._iterate_postorder(node.right)
337         yield node.value
338
339     def iterate_preorder(self):
340         """
341         Yield the tree's items in a preorder traversal sequence.
342
343         >>> t = BinarySearchTree()
344         >>> t.insert(50)
345         >>> t.insert(75)
346         >>> t.insert(25)
347         >>> t.insert(66)
348         >>> t.insert(22)
349         >>> t.insert(13)
350
351         >>> for value in t.iterate_preorder():
352         ...     print(value)
353         50
354         25
355         22
356         13
357         75
358         66
359
360         """
361         if self.root is not None:
362             yield from self._iterate_preorder(self.root)
363
364     def iterate_inorder(self):
365         """
366         Yield the tree's items in a preorder traversal sequence.
367
368         >>> t = BinarySearchTree()
369         >>> t.insert(50)
370         >>> t.insert(75)
371         >>> t.insert(25)
372         >>> t.insert(66)
373         >>> t.insert(22)
374         >>> t.insert(13)
375         >>> t.insert(24)
376         >>> t
377         50
378         ├──25
379         │  └──22
380         │     ├──13
381         │     └──24
382         └──75
383            └──66
384
385         >>> for value in t.iterate_inorder():
386         ...     print(value)
387         13
388         22
389         24
390         25
391         50
392         66
393         75
394
395         """
396         if self.root is not None:
397             yield from self._iterate_inorder(self.root)
398
399     def iterate_postorder(self):
400         """
401         Yield the tree's items in a preorder traversal sequence.
402
403         >>> t = BinarySearchTree()
404         >>> t.insert(50)
405         >>> t.insert(75)
406         >>> t.insert(25)
407         >>> t.insert(66)
408         >>> t.insert(22)
409         >>> t.insert(13)
410
411         >>> for value in t.iterate_postorder():
412         ...     print(value)
413         13
414         22
415         25
416         66
417         75
418         50
419
420         """
421         if self.root is not None:
422             yield from self._iterate_postorder(self.root)
423
424     def _iterate_leaves(self, node: Node):
425         if node.left is not None:
426             yield from self._iterate_leaves(node.left)
427         if node.right is not None:
428             yield from self._iterate_leaves(node.right)
429         if node.left is None and node.right is None:
430             yield node.value
431
432     def iterate_leaves(self):
433         """
434         Iterate only the leaf nodes in the tree.
435
436         >>> t = BinarySearchTree()
437         >>> t.insert(50)
438         >>> t.insert(75)
439         >>> t.insert(25)
440         >>> t.insert(66)
441         >>> t.insert(22)
442         >>> t.insert(13)
443
444         >>> for value in t.iterate_leaves():
445         ...     print(value)
446         13
447         66
448
449         """
450         if self.root is not None:
451             yield from self._iterate_leaves(self.root)
452
453     def _iterate_by_depth(self, node: Node, depth: int):
454         if depth == 0:
455             yield node.value
456         else:
457             assert depth > 0
458             if node.left is not None:
459                 yield from self._iterate_by_depth(node.left, depth - 1)
460             if node.right is not None:
461                 yield from self._iterate_by_depth(node.right, depth - 1)
462
463     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
464         """
465         Iterate only the leaf nodes in the tree.
466
467         >>> t = BinarySearchTree()
468         >>> t.insert(50)
469         >>> t.insert(75)
470         >>> t.insert(25)
471         >>> t.insert(66)
472         >>> t.insert(22)
473         >>> t.insert(13)
474
475         >>> for value in t.iterate_nodes_by_depth(2):
476         ...     print(value)
477         22
478         66
479
480         >>> for value in t.iterate_nodes_by_depth(3):
481         ...     print(value)
482         13
483
484         """
485         if self.root is not None:
486             yield from self._iterate_by_depth(self.root, depth)
487
488     def get_next_node(self, node: Node) -> Node:
489         """
490         Given a tree node, get the next greater node in the tree.
491
492         >>> t = BinarySearchTree()
493         >>> t.insert(50)
494         >>> t.insert(75)
495         >>> t.insert(25)
496         >>> t.insert(66)
497         >>> t.insert(22)
498         >>> t.insert(13)
499         >>> t.insert(23)
500         >>> t
501         50
502         ├──25
503         │  └──22
504         │     ├──13
505         │     └──23
506         └──75
507            └──66
508
509         >>> n = t[23]
510         >>> t.get_next_node(n).value
511         25
512
513         >>> n = t[50]
514         >>> t.get_next_node(n).value
515         66
516
517         """
518         if node.right is not None:
519             x = node.right
520             while x.left is not None:
521                 x = x.left
522             return x
523
524         path = self.parent_path(node)
525         assert path[-1] is not None
526         assert path[-1] == node
527         path = path[:-1]
528         path.reverse()
529         for ancestor in path:
530             assert ancestor is not None
531             if node != ancestor.right:
532                 return ancestor
533             node = ancestor
534         raise Exception()
535
536     def _depth(self, node: Node, sofar: int) -> int:
537         depth_left = sofar + 1
538         depth_right = sofar + 1
539         if node.left is not None:
540             depth_left = self._depth(node.left, sofar + 1)
541         if node.right is not None:
542             depth_right = self._depth(node.right, sofar + 1)
543         return max(depth_left, depth_right)
544
545     def depth(self):
546         """
547         Returns the max height (depth) of the tree in plies (edge distance
548         from root).
549
550         >>> t = BinarySearchTree()
551         >>> t.depth()
552         0
553
554         >>> t.insert(50)
555         >>> t.depth()
556         1
557
558         >>> t.insert(65)
559         >>> t.depth()
560         2
561
562         >>> t.insert(33)
563         >>> t.depth()
564         2
565
566         >>> t.insert(2)
567         >>> t.insert(1)
568         >>> t.depth()
569         4
570
571         """
572         if self.root is None:
573             return 0
574         return self._depth(self.root, 0)
575
576     def height(self):
577         return self.depth()
578
579     def repr_traverse(
580         self,
581         padding: str,
582         pointer: str,
583         node: Optional[Node],
584         has_right_sibling: bool,
585     ) -> str:
586         if node is not None:
587             viz = f'\n{padding}{pointer}{node.value}'
588             if has_right_sibling:
589                 padding += "│  "
590             else:
591                 padding += '   '
592
593             pointer_right = "└──"
594             if node.right is not None:
595                 pointer_left = "├──"
596             else:
597                 pointer_left = "└──"
598
599             viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
600             viz += self.repr_traverse(padding, pointer_right, node.right, False)
601             return viz
602         return ""
603
604     def __repr__(self):
605         """
606         Draw the tree in ASCII.
607
608         >>> t = BinarySearchTree()
609         >>> t.insert(50)
610         >>> t.insert(25)
611         >>> t.insert(75)
612         >>> t.insert(12)
613         >>> t.insert(33)
614         >>> t.insert(88)
615         >>> t.insert(55)
616         >>> t
617         50
618         ├──25
619         │  ├──12
620         │  └──33
621         └──75
622            ├──55
623            └──88
624         """
625         if self.root is None:
626             return ""
627
628         ret = f'{self.root.value}'
629         pointer_right = "└──"
630         if self.root.right is None:
631             pointer_left = "└──"
632         else:
633             pointer_left = "├──"
634
635         ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
636         ret += self.repr_traverse('', pointer_right, self.root.right, False)
637         return ret
638
639
640 if __name__ == '__main__':
641     import doctest
642
643     doctest.testmod()