Adds shuffle/scramble to list_utils.
[python_utils.git] / collect / bst.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Generator, List, Optional
4
5
6 class Node(object):
7     def __init__(self, value: Any) -> None:
8         """
9         Note: value can be anything as long as it is comparable.
10         Check out @functools.total_ordering.
11         """
12         self.left: Optional[Node] = None
13         self.right: Optional[Node] = None
14         self.value = value
15
16
17 class BinarySearchTree(object):
18     def __init__(self):
19         self.root = None
20         self.count = 0
21         self.traverse = None
22
23     def get_root(self) -> Optional[Node]:
24         return self.root
25
26     def insert(self, value: Any):
27         """
28         Insert something into the tree.
29
30         >>> t = BinarySearchTree()
31         >>> t.insert(10)
32         >>> t.insert(20)
33         >>> t.insert(5)
34         >>> len(t)
35         3
36
37         >>> t.get_root().value
38         10
39
40         """
41         if self.root is None:
42             self.root = Node(value)
43             self.count = 1
44         else:
45             self._insert(value, self.root)
46
47     def _insert(self, value: Any, node: Node):
48         """Insertion helper"""
49         if value < node.value:
50             if node.left is not None:
51                 self._insert(value, node.left)
52             else:
53                 node.left = Node(value)
54                 self.count += 1
55         else:
56             if node.right is not None:
57                 self._insert(value, node.right)
58             else:
59                 node.right = Node(value)
60                 self.count += 1
61
62     def __getitem__(self, value: Any) -> Optional[Node]:
63         """
64         Find an item in the tree and return its Node.  Returns
65         None if the item is not in the tree.
66
67         >>> t = BinarySearchTree()
68         >>> t[99]
69
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> t[10].value
74         10
75
76         >>> t[99]
77
78         """
79         if self.root is not None:
80             return self._find(value, self.root)
81         return None
82
83     def _find(self, value: Any, node: Node) -> Optional[Node]:
84         """Find helper"""
85         if value == node.value:
86             return node
87         elif value < node.value and node.left is not None:
88             return self._find(value, node.left)
89         elif value > node.value and node.right is not None:
90             return self._find(value, node.right)
91         return None
92
93     def _parent_path(self, current: Optional[Node], target: Node) -> List[Optional[Node]]:
94         if current is None:
95             return [None]
96         ret: List[Optional[Node]] = [current]
97         if target.value == current.value:
98             return ret
99         elif target.value < current.value:
100             ret.extend(self._parent_path(current.left, target))
101             return ret
102         else:
103             assert target.value > current.value
104             ret.extend(self._parent_path(current.right, target))
105             return ret
106
107     def parent_path(self, node: Node) -> List[Optional[Node]]:
108         """Return a list of nodes representing the path from
109         the tree's root to the node argument.  If the node does
110         not exist in the tree for some reason, the last element
111         on the path will be None but the path will indicate the
112         ancestor path of that node were it inserted.
113
114         >>> t = BinarySearchTree()
115         >>> t.insert(50)
116         >>> t.insert(75)
117         >>> t.insert(25)
118         >>> t.insert(12)
119         >>> t.insert(33)
120         >>> t.insert(4)
121         >>> t.insert(88)
122         >>> t
123         50
124         ├──25
125         │  ├──12
126         │  │  └──4
127         │  └──33
128         └──75
129            └──88
130
131         >>> n = t[4]
132         >>> for x in t.parent_path(n):
133         ...     print(x.value)
134         50
135         25
136         12
137         4
138
139         >>> del t[4]
140         >>> for x in t.parent_path(n):
141         ...     if x is not None:
142         ...         print(x.value)
143         ...     else:
144         ...         print(x)
145         50
146         25
147         12
148         None
149
150         """
151         return self._parent_path(self.root, node)
152
153     def __delitem__(self, value: Any) -> bool:
154         """
155         Delete an item from the tree and preserve the BST property.
156
157         >>> t = BinarySearchTree()
158         >>> t.insert(50)
159         >>> t.insert(75)
160         >>> t.insert(25)
161         >>> t.insert(66)
162         >>> t.insert(22)
163         >>> t.insert(13)
164         >>> t.insert(85)
165         >>> t
166         50
167         ├──25
168         │  └──22
169         │     └──13
170         └──75
171            ├──66
172            └──85
173
174         >>> for value in t.iterate_inorder():
175         ...     print(value)
176         13
177         22
178         25
179         50
180         66
181         75
182         85
183
184         >>> del t[22]  # Note: bool result is discarded
185
186         >>> for value in t.iterate_inorder():
187         ...     print(value)
188         13
189         25
190         50
191         66
192         75
193         85
194
195         >>> t.__delitem__(13)
196         True
197         >>> for value in t.iterate_inorder():
198         ...     print(value)
199         25
200         50
201         66
202         75
203         85
204
205         >>> t.__delitem__(75)
206         True
207         >>> for value in t.iterate_inorder():
208         ...     print(value)
209         25
210         50
211         66
212         85
213         >>> t
214         50
215         ├──25
216         └──85
217            └──66
218
219         >>> t.__delitem__(99)
220         False
221
222         """
223         if self.root is not None:
224             ret = self._delete(value, None, self.root)
225             if ret:
226                 self.count -= 1
227                 if self.count == 0:
228                     self.root = None
229             return ret
230         return False
231
232     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
233         """Delete helper"""
234         if node.value == value:
235             # Deleting a leaf node
236             if node.left is None and node.right is None:
237                 if parent is not None:
238                     if parent.left == node:
239                         parent.left = None
240                     else:
241                         assert parent.right == node
242                         parent.right = None
243                 return True
244
245             # Node only has a right.
246             elif node.left is None:
247                 assert node.right is not None
248                 if parent is not None:
249                     if parent.left == node:
250                         parent.left = node.right
251                     else:
252                         assert parent.right == node
253                         parent.right = node.right
254                 return True
255
256             # Node only has a left.
257             elif node.right is None:
258                 assert node.left is not None
259                 if parent is not None:
260                     if parent.left == node:
261                         parent.left = node.left
262                     else:
263                         assert parent.right == node
264                         parent.right = node.left
265                 return True
266
267             # Node has both a left and right.
268             else:
269                 assert node.left is not None and node.right is not None
270                 descendent = node.right
271                 while descendent.left is not None:
272                     descendent = descendent.left
273                 node.value = descendent.value
274                 return self._delete(node.value, node, node.right)
275         elif value < node.value and node.left is not None:
276             return self._delete(value, node, node.left)
277         elif value > node.value and node.right is not None:
278             return self._delete(value, node, node.right)
279         return False
280
281     def __len__(self):
282         """
283         Returns the count of items in the tree.
284
285         >>> t = BinarySearchTree()
286         >>> len(t)
287         0
288         >>> t.insert(50)
289         >>> len(t)
290         1
291         >>> t.__delitem__(50)
292         True
293         >>> len(t)
294         0
295         >>> t.insert(75)
296         >>> t.insert(25)
297         >>> t.insert(66)
298         >>> t.insert(22)
299         >>> t.insert(13)
300         >>> t.insert(85)
301         >>> len(t)
302         6
303
304         """
305         return self.count
306
307     def __contains__(self, value: Any) -> bool:
308         """
309         Returns True if the item is in the tree; False otherwise.
310
311         """
312         return self.__getitem__(value) is not None
313
314     def _iterate_preorder(self, node: Node):
315         yield node.value
316         if node.left is not None:
317             yield from self._iterate_preorder(node.left)
318         if node.right is not None:
319             yield from self._iterate_preorder(node.right)
320
321     def _iterate_inorder(self, node: Node):
322         if node.left is not None:
323             yield from self._iterate_inorder(node.left)
324         yield node.value
325         if node.right is not None:
326             yield from self._iterate_inorder(node.right)
327
328     def _iterate_postorder(self, node: Node):
329         if node.left is not None:
330             yield from self._iterate_postorder(node.left)
331         if node.right is not None:
332             yield from self._iterate_postorder(node.right)
333         yield node.value
334
335     def iterate_preorder(self):
336         """
337         Yield the tree's items in a preorder traversal sequence.
338
339         >>> t = BinarySearchTree()
340         >>> t.insert(50)
341         >>> t.insert(75)
342         >>> t.insert(25)
343         >>> t.insert(66)
344         >>> t.insert(22)
345         >>> t.insert(13)
346
347         >>> for value in t.iterate_preorder():
348         ...     print(value)
349         50
350         25
351         22
352         13
353         75
354         66
355
356         """
357         if self.root is not None:
358             yield from self._iterate_preorder(self.root)
359
360     def iterate_inorder(self):
361         """
362         Yield the tree's items in a preorder traversal sequence.
363
364         >>> t = BinarySearchTree()
365         >>> t.insert(50)
366         >>> t.insert(75)
367         >>> t.insert(25)
368         >>> t.insert(66)
369         >>> t.insert(22)
370         >>> t.insert(13)
371         >>> t.insert(24)
372         >>> t
373         50
374         ├──25
375         │  └──22
376         │     ├──13
377         │     └──24
378         └──75
379            └──66
380
381         >>> for value in t.iterate_inorder():
382         ...     print(value)
383         13
384         22
385         24
386         25
387         50
388         66
389         75
390
391         """
392         if self.root is not None:
393             yield from self._iterate_inorder(self.root)
394
395     def iterate_postorder(self):
396         """
397         Yield the tree's items in a preorder traversal sequence.
398
399         >>> t = BinarySearchTree()
400         >>> t.insert(50)
401         >>> t.insert(75)
402         >>> t.insert(25)
403         >>> t.insert(66)
404         >>> t.insert(22)
405         >>> t.insert(13)
406
407         >>> for value in t.iterate_postorder():
408         ...     print(value)
409         13
410         22
411         25
412         66
413         75
414         50
415
416         """
417         if self.root is not None:
418             yield from self._iterate_postorder(self.root)
419
420     def _iterate_leaves(self, node: Node):
421         if node.left is not None:
422             yield from self._iterate_leaves(node.left)
423         if node.right is not None:
424             yield from self._iterate_leaves(node.right)
425         if node.left is None and node.right is None:
426             yield node.value
427
428     def iterate_leaves(self):
429         """
430         Iterate only the leaf nodes in the tree.
431
432         >>> t = BinarySearchTree()
433         >>> t.insert(50)
434         >>> t.insert(75)
435         >>> t.insert(25)
436         >>> t.insert(66)
437         >>> t.insert(22)
438         >>> t.insert(13)
439
440         >>> for value in t.iterate_leaves():
441         ...     print(value)
442         13
443         66
444
445         """
446         if self.root is not None:
447             yield from self._iterate_leaves(self.root)
448
449     def _iterate_by_depth(self, node: Node, depth: int):
450         if depth == 0:
451             yield node.value
452         else:
453             assert depth > 0
454             if node.left is not None:
455                 yield from self._iterate_by_depth(node.left, depth - 1)
456             if node.right is not None:
457                 yield from self._iterate_by_depth(node.right, depth - 1)
458
459     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
460         """
461         Iterate only the leaf nodes in the tree.
462
463         >>> t = BinarySearchTree()
464         >>> t.insert(50)
465         >>> t.insert(75)
466         >>> t.insert(25)
467         >>> t.insert(66)
468         >>> t.insert(22)
469         >>> t.insert(13)
470
471         >>> for value in t.iterate_nodes_by_depth(2):
472         ...     print(value)
473         22
474         66
475
476         >>> for value in t.iterate_nodes_by_depth(3):
477         ...     print(value)
478         13
479
480         """
481         if self.root is not None:
482             yield from self._iterate_by_depth(self.root, depth)
483
484     def get_next_node(self, node: Node) -> Node:
485         """
486         Given a tree node, get the next greater node in the tree.
487
488         >>> t = BinarySearchTree()
489         >>> t.insert(50)
490         >>> t.insert(75)
491         >>> t.insert(25)
492         >>> t.insert(66)
493         >>> t.insert(22)
494         >>> t.insert(13)
495         >>> t.insert(23)
496         >>> t
497         50
498         ├──25
499         │  └──22
500         │     ├──13
501         │     └──23
502         └──75
503            └──66
504
505         >>> n = t[23]
506         >>> t.get_next_node(n).value
507         25
508
509         >>> n = t[50]
510         >>> t.get_next_node(n).value
511         66
512
513         """
514         if node.right is not None:
515             x = node.right
516             while x.left is not None:
517                 x = x.left
518             return x
519
520         path = self.parent_path(node)
521         assert path[-1] is not None
522         assert path[-1] == node
523         path = path[:-1]
524         path.reverse()
525         for ancestor in path:
526             assert ancestor is not None
527             if node != ancestor.right:
528                 return ancestor
529             node = ancestor
530         raise Exception()
531
532     def _depth(self, node: Node, sofar: int) -> int:
533         depth_left = sofar + 1
534         depth_right = sofar + 1
535         if node.left is not None:
536             depth_left = self._depth(node.left, sofar + 1)
537         if node.right is not None:
538             depth_right = self._depth(node.right, sofar + 1)
539         return max(depth_left, depth_right)
540
541     def depth(self):
542         """
543         Returns the max height (depth) of the tree in plies (edge distance
544         from root).
545
546         >>> t = BinarySearchTree()
547         >>> t.depth()
548         0
549
550         >>> t.insert(50)
551         >>> t.depth()
552         1
553
554         >>> t.insert(65)
555         >>> t.depth()
556         2
557
558         >>> t.insert(33)
559         >>> t.depth()
560         2
561
562         >>> t.insert(2)
563         >>> t.insert(1)
564         >>> t.depth()
565         4
566
567         """
568         if self.root is None:
569             return 0
570         return self._depth(self.root, 0)
571
572     def height(self):
573         return self.depth()
574
575     def repr_traverse(
576         self,
577         padding: str,
578         pointer: str,
579         node: Optional[Node],
580         has_right_sibling: bool,
581     ) -> str:
582         if node is not None:
583             viz = f'\n{padding}{pointer}{node.value}'
584             if has_right_sibling:
585                 padding += "│  "
586             else:
587                 padding += '   '
588
589             pointer_right = "└──"
590             if node.right is not None:
591                 pointer_left = "├──"
592             else:
593                 pointer_left = "└──"
594
595             viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
596             viz += self.repr_traverse(padding, pointer_right, node.right, False)
597             return viz
598         return ""
599
600     def __repr__(self):
601         """
602         Draw the tree in ASCII.
603
604         >>> t = BinarySearchTree()
605         >>> t.insert(50)
606         >>> t.insert(25)
607         >>> t.insert(75)
608         >>> t.insert(12)
609         >>> t.insert(33)
610         >>> t.insert(88)
611         >>> t.insert(55)
612         >>> t
613         50
614         ├──25
615         │  ├──12
616         │  └──33
617         └──75
618            ├──55
619            └──88
620         """
621         if self.root is None:
622             return ""
623
624         ret = f'{self.root.value}'
625         pointer_right = "└──"
626         if self.root.right is None:
627             pointer_left = "└──"
628         else:
629             pointer_left = "├──"
630
631         ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
632         ret += self.repr_traverse('', pointer_right, self.root.right, False)
633         return ret
634
635
636 if __name__ == '__main__':
637     import doctest
638
639     doctest.testmod()