changes
[python_utils.git] / collect / bst.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Optional, List
4
5
6 class Node(object):
7     def __init__(self, value: Any) -> None:
8         """
9         Note: value can be anything as long as it is comparable.
10         Check out @functools.total_ordering.
11         """
12         self.left = None
13         self.right = None
14         self.value = value
15
16
17 class BinarySearchTree(object):
18     def __init__(self):
19         self.root = None
20         self.count = 0
21         self.traverse = None
22
23     def get_root(self) -> Optional[Node]:
24         return self.root
25
26     def insert(self, value: Any):
27         """
28         Insert something into the tree.
29
30         >>> t = BinarySearchTree()
31         >>> t.insert(10)
32         >>> t.insert(20)
33         >>> t.insert(5)
34         >>> len(t)
35         3
36
37         >>> t.get_root().value
38         10
39
40         """
41         if self.root is None:
42             self.root = Node(value)
43             self.count = 1
44         else:
45             self._insert(value, self.root)
46
47     def _insert(self, value: Any, node: Node):
48         """Insertion helper"""
49         if value < node.value:
50             if node.left is not None:
51                 self._insert(value, node.left)
52             else:
53                 node.left = Node(value)
54                 self.count += 1
55         else:
56             if node.right is not None:
57                 self._insert(value, node.right)
58             else:
59                 node.right = Node(value)
60                 self.count += 1
61
62     def __getitem__(self, value: Any) -> Optional[Node]:
63         """
64         Find an item in the tree and return its Node.  Returns
65         None if the item is not in the tree.
66
67         >>> t = BinarySearchTree()
68         >>> t[99]
69
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> t[10].value
74         10
75
76         >>> t[99]
77
78         """
79         if self.root is not None:
80             return self._find(value, self.root)
81         return None
82
83     def _find(self, value: Any, node: Node) -> Optional[Node]:
84         """Find helper"""
85         if value == node.value:
86             return node
87         elif (value < node.value and node.left is not None):
88             return self._find(value, node.left)
89         elif (value > node.value and node.right is not None):
90             return self._find(value, node.right)
91         return None
92
93     def _parent_path(self, current: Node, target: Node):
94         if current is None:
95             return [None]
96         ret = [current]
97         if target.value == current.value:
98             return ret
99         elif target.value < current.value:
100             ret.extend(self._parent_path(current.left, target))
101             return ret
102         else:
103             assert target.value > current.value
104             ret.extend(self._parent_path(current.right, target))
105             return ret
106
107     def parent_path(self, node: Node) -> Optional[List[Node]]:
108         """Return a list of nodes representing the path from
109         the tree's root to the node argument.  If the node does
110         not exist in the tree for some reason, the last element
111         on the path will be None but the path will indicate the
112         ancestor path of that node were it inserted.
113
114         >>> t = BinarySearchTree()
115         >>> t.insert(50)
116         >>> t.insert(75)
117         >>> t.insert(25)
118         >>> t.insert(12)
119         >>> t.insert(33)
120         >>> t.insert(4)
121         >>> t.insert(88)
122         >>> t
123         50
124         ├──25
125         │  ├──12
126         │  │  └──4
127         │  └──33
128         └──75
129            └──88
130
131         >>> n = t[4]
132         >>> for x in t.parent_path(n):
133         ...     print(x.value)
134         50
135         25
136         12
137         4
138
139         >>> del t[4]
140         >>> for x in t.parent_path(n):
141         ...     if x is not None:
142         ...         print(x.value)
143         ...     else:
144         ...         print(x)
145         50
146         25
147         12
148         None
149
150         """
151         return self._parent_path(self.root, node)
152
153     def __delitem__(self, value: Any) -> bool:
154         """
155         Delete an item from the tree and preserve the BST property.
156
157         >>> t = BinarySearchTree()
158         >>> t.insert(50)
159         >>> t.insert(75)
160         >>> t.insert(25)
161         >>> t.insert(66)
162         >>> t.insert(22)
163         >>> t.insert(13)
164         >>> t.insert(85)
165         >>> t
166         50
167         ├──25
168         │  └──22
169         │     └──13
170         └──75
171            ├──66
172            └──85
173
174         >>> for value in t.iterate_inorder():
175         ...     print(value)
176         13
177         22
178         25
179         50
180         66
181         75
182         85
183
184         >>> del t[22]  # Note: bool result is discarded
185
186         >>> for value in t.iterate_inorder():
187         ...     print(value)
188         13
189         25
190         50
191         66
192         75
193         85
194
195         >>> t.__delitem__(13)
196         True
197         >>> for value in t.iterate_inorder():
198         ...     print(value)
199         25
200         50
201         66
202         75
203         85
204
205         >>> t.__delitem__(75)
206         True
207         >>> for value in t.iterate_inorder():
208         ...     print(value)
209         25
210         50
211         66
212         85
213         >>> t
214         50
215         ├──25
216         └──85
217            └──66
218
219         >>> t.__delitem__(99)
220         False
221
222         """
223         if self.root is not None:
224             ret = self._delete(value, None, self.root)
225             if ret:
226                 self.count -= 1
227                 if self.count == 0:
228                     self.root = None
229             return ret
230         return False
231
232     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
233         """Delete helper"""
234         if node.value == value:
235             # Deleting a leaf node
236             if node.left is None and node.right is None:
237                 if parent is not None:
238                     if parent.left == node:
239                         parent.left = None
240                     else:
241                         assert parent.right == node
242                         parent.right = None
243                 return True
244
245             # Node only has a right.
246             elif node.left is None:
247                 assert node.right is not None
248                 if parent is not None:
249                     if parent.left == node:
250                         parent.left = node.right
251                     else:
252                         assert parent.right == node
253                         parent.right = node.right
254                 return True
255
256             # Node only has a left.
257             elif node.right is None:
258                 assert node.left is not None
259                 if parent is not None:
260                     if parent.left == node:
261                         parent.left = node.left
262                     else:
263                         assert parent.right == node
264                         parent.right = node.left
265                 return True
266
267             # Node has both a left and right.
268             else:
269                 assert node.left is not None and node.right is not None
270                 descendent = node.right
271                 while descendent.left is not None:
272                     descendent = descendent.left
273                 node.value = descendent.value
274                 return self._delete(node.value, node, node.right)
275         elif value < node.value and node.left is not None:
276             return self._delete(value, node, node.left)
277         elif value > node.value and node.right is not None:
278             return self._delete(value, node, node.right)
279         return False
280
281     def __len__(self):
282         """
283         Returns the count of items in the tree.
284
285         >>> t = BinarySearchTree()
286         >>> len(t)
287         0
288         >>> t.insert(50)
289         >>> len(t)
290         1
291         >>> t.__delitem__(50)
292         True
293         >>> len(t)
294         0
295         >>> t.insert(75)
296         >>> t.insert(25)
297         >>> t.insert(66)
298         >>> t.insert(22)
299         >>> t.insert(13)
300         >>> t.insert(85)
301         >>> len(t)
302         6
303
304         """
305         return self.count
306
307     def __contains__(self, value: Any) -> bool:
308         """
309         Returns True if the item is in the tree; False otherwise.
310
311         """
312         return self.__getitem__(value) is not None
313
314     def _iterate_preorder(self, node: Node):
315         yield node.value
316         if node.left is not None:
317             yield from self._iterate_preorder(node.left)
318         if node.right is not None:
319             yield from self._iterate_preorder(node.right)
320
321     def _iterate_inorder(self, node: Node):
322         if node.left is not None:
323             yield from self._iterate_inorder(node.left)
324         yield node.value
325         if node.right is not None:
326             yield from self._iterate_inorder(node.right)
327
328     def _iterate_postorder(self, node: Node):
329         if node.left is not None:
330             yield from self._iterate_postorder(node.left)
331         if node.right is not None:
332             yield from self._iterate_postorder(node.right)
333         yield node.value
334
335     def iterate_preorder(self):
336         """
337         Yield the tree's items in a preorder traversal sequence.
338
339         >>> t = BinarySearchTree()
340         >>> t.insert(50)
341         >>> t.insert(75)
342         >>> t.insert(25)
343         >>> t.insert(66)
344         >>> t.insert(22)
345         >>> t.insert(13)
346
347         >>> for value in t.iterate_preorder():
348         ...     print(value)
349         50
350         25
351         22
352         13
353         75
354         66
355
356         """
357         if self.root is not None:
358             yield from self._iterate_preorder(self.root)
359
360     def iterate_inorder(self):
361         """
362         Yield the tree's items in a preorder traversal sequence.
363
364         >>> t = BinarySearchTree()
365         >>> t.insert(50)
366         >>> t.insert(75)
367         >>> t.insert(25)
368         >>> t.insert(66)
369         >>> t.insert(22)
370         >>> t.insert(13)
371         >>> t.insert(24)
372         >>> t
373         50
374         ├──25
375         │  └──22
376         │     ├──13
377         │     └──24
378         └──75
379            └──66
380
381         >>> for value in t.iterate_inorder():
382         ...     print(value)
383         13
384         22
385         24
386         25
387         50
388         66
389         75
390
391         """
392         if self.root is not None:
393             yield from self._iterate_inorder(self.root)
394
395     def iterate_postorder(self):
396         """
397         Yield the tree's items in a preorder traversal sequence.
398
399         >>> t = BinarySearchTree()
400         >>> t.insert(50)
401         >>> t.insert(75)
402         >>> t.insert(25)
403         >>> t.insert(66)
404         >>> t.insert(22)
405         >>> t.insert(13)
406
407         >>> for value in t.iterate_postorder():
408         ...     print(value)
409         13
410         22
411         25
412         66
413         75
414         50
415
416         """
417         if self.root is not None:
418             yield from self._iterate_postorder(self.root)
419
420     def _iterate_leaves(self, node: Node):
421         if node.left is not None:
422             yield from self._iterate_leaves(node.left)
423         if node.right is not None:
424             yield from self._iterate_leaves(node.right)
425         if node.left is None and node.right is None:
426             yield node.value
427
428     def iterate_leaves(self):
429         """
430         Iterate only the leaf nodes in the tree.
431
432         >>> t = BinarySearchTree()
433         >>> t.insert(50)
434         >>> t.insert(75)
435         >>> t.insert(25)
436         >>> t.insert(66)
437         >>> t.insert(22)
438         >>> t.insert(13)
439
440         >>> for value in t.iterate_leaves():
441         ...     print(value)
442         13
443         66
444
445         """
446         if self.root is not None:
447             yield from self._iterate_leaves(self.root)
448
449     def _iterate_by_depth(self, node: Node, depth: int):
450         if depth == 0:
451             yield node.value
452         else:
453             assert depth > 0
454             if node.left is not None:
455                 yield from self._iterate_by_depth(node.left, depth - 1)
456             if node.right is not None:
457                 yield from self._iterate_by_depth(node.right, depth - 1)
458
459     def iterate_nodes_by_depth(self, depth: int):
460         """
461         Iterate only the leaf nodes in the tree.
462
463         >>> t = BinarySearchTree()
464         >>> t.insert(50)
465         >>> t.insert(75)
466         >>> t.insert(25)
467         >>> t.insert(66)
468         >>> t.insert(22)
469         >>> t.insert(13)
470
471         >>> for value in t.iterate_nodes_by_depth(2):
472         ...     print(value)
473         22
474         66
475
476         >>> for value in t.iterate_nodes_by_depth(3):
477         ...     print(value)
478         13
479
480         """
481         if self.root is not None:
482             yield from self._iterate_by_depth(self.root, depth)
483
484     def get_next_node(self, node: Node) -> Node:
485         """
486         Given a tree node, get the next greater node in the tree.
487
488         >>> t = BinarySearchTree()
489         >>> t.insert(50)
490         >>> t.insert(75)
491         >>> t.insert(25)
492         >>> t.insert(66)
493         >>> t.insert(22)
494         >>> t.insert(13)
495         >>> t.insert(23)
496         >>> t
497         50
498         ├──25
499         │  └──22
500         │     ├──13
501         │     └──23
502         └──75
503            └──66
504
505         >>> n = t[23]
506         >>> t.get_next_node(n).value
507         25
508
509         >>> n = t[50]
510         >>> t.get_next_node(n).value
511         66
512
513         """
514         if node.right is not None:
515             x = node.right
516             while x.left is not None:
517                 x = x.left
518             return x
519
520         path = self.parent_path(node)
521         assert path[-1] == node
522         path = path[:-1]
523         path.reverse()
524         for ancestor in path:
525             if node != ancestor.right:
526                 return ancestor
527             node = ancestor
528
529     def _depth(self, node: Node, sofar: int) -> int:
530         depth_left = sofar + 1
531         depth_right = sofar + 1
532         if node.left is not None:
533             depth_left = self._depth(node.left, sofar + 1)
534         if node.right is not None:
535             depth_right = self._depth(node.right, sofar + 1)
536         return max(depth_left, depth_right)
537
538     def depth(self):
539         """
540         Returns the max height (depth) of the tree in plies (edge distance
541         from root).
542
543         >>> t = BinarySearchTree()
544         >>> t.depth()
545         0
546
547         >>> t.insert(50)
548         >>> t.depth()
549         1
550
551         >>> t.insert(65)
552         >>> t.depth()
553         2
554
555         >>> t.insert(33)
556         >>> t.depth()
557         2
558
559         >>> t.insert(2)
560         >>> t.insert(1)
561         >>> t.depth()
562         4
563
564         """
565         if self.root is None:
566             return 0
567         return self._depth(self.root, 0)
568
569     def height(self):
570         return self.depth()
571
572     def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
573         if node is not None:
574             viz = f'\n{padding}{pointer}{node.value}'
575             if has_right_sibling:
576                 padding += "│  "
577             else:
578                 padding += '   '
579
580             pointer_right = "└──"
581             if node.right is not None:
582                 pointer_left = "├──"
583             else:
584                 pointer_left = "└──"
585
586             viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
587             viz += self.repr_traverse(padding, pointer_right, node.right, False)
588             return viz
589         return ""
590
591     def __repr__(self):
592         """
593         Draw the tree in ASCII.
594
595         >>> t = BinarySearchTree()
596         >>> t.insert(50)
597         >>> t.insert(25)
598         >>> t.insert(75)
599         >>> t.insert(12)
600         >>> t.insert(33)
601         >>> t.insert(88)
602         >>> t.insert(55)
603         >>> t
604         50
605         ├──25
606         │  ├──12
607         │  └──33
608         └──75
609            ├──55
610            └──88
611         """
612         if self.root is None:
613             return ""
614
615         ret = f'{self.root.value}'
616         pointer_right = "└──"
617         if self.root.right is None:
618             pointer_left = "└──"
619         else:
620             pointer_left = "├──"
621
622         ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
623         ret += self.repr_traverse('', pointer_right, self.root.right, False)
624         return ret
625
626
627 if __name__ == '__main__':
628     import doctest
629     doctest.testmod()