Let's be explicit with asserts; there was a bug in histogram
[python_utils.git] / collect / bst.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Generator, List, Optional
4
5
6 class Node(object):
7     def __init__(self, value: Any) -> None:
8         """
9         Note: value can be anything as long as it is comparable.
10         Check out @functools.total_ordering.
11         """
12         self.left: Optional[Node] = None
13         self.right: Optional[Node] = None
14         self.value = value
15
16
17 class BinarySearchTree(object):
18     def __init__(self):
19         self.root = None
20         self.count = 0
21         self.traverse = None
22
23     def get_root(self) -> Optional[Node]:
24         return self.root
25
26     def insert(self, value: Any):
27         """
28         Insert something into the tree.
29
30         >>> t = BinarySearchTree()
31         >>> t.insert(10)
32         >>> t.insert(20)
33         >>> t.insert(5)
34         >>> len(t)
35         3
36
37         >>> t.get_root().value
38         10
39
40         """
41         if self.root is None:
42             self.root = Node(value)
43             self.count = 1
44         else:
45             self._insert(value, self.root)
46
47     def _insert(self, value: Any, node: Node):
48         """Insertion helper"""
49         if value < node.value:
50             if node.left is not None:
51                 self._insert(value, node.left)
52             else:
53                 node.left = Node(value)
54                 self.count += 1
55         else:
56             if node.right is not None:
57                 self._insert(value, node.right)
58             else:
59                 node.right = Node(value)
60                 self.count += 1
61
62     def __getitem__(self, value: Any) -> Optional[Node]:
63         """
64         Find an item in the tree and return its Node.  Returns
65         None if the item is not in the tree.
66
67         >>> t = BinarySearchTree()
68         >>> t[99]
69
70         >>> t.insert(10)
71         >>> t.insert(20)
72         >>> t.insert(5)
73         >>> t[10].value
74         10
75
76         >>> t[99]
77
78         """
79         if self.root is not None:
80             return self._find(value, self.root)
81         return None
82
83     def _find(self, value: Any, node: Node) -> Optional[Node]:
84         """Find helper"""
85         if value == node.value:
86             return node
87         elif value < node.value and node.left is not None:
88             return self._find(value, node.left)
89         elif value > node.value and node.right is not None:
90             return self._find(value, node.right)
91         return None
92
93     def _parent_path(
94         self, current: Optional[Node], target: Node
95     ) -> List[Optional[Node]]:
96         if current is None:
97             return [None]
98         ret: List[Optional[Node]] = [current]
99         if target.value == current.value:
100             return ret
101         elif target.value < current.value:
102             ret.extend(self._parent_path(current.left, target))
103             return ret
104         else:
105             assert target.value > current.value
106             ret.extend(self._parent_path(current.right, target))
107             return ret
108
109     def parent_path(self, node: Node) -> List[Optional[Node]]:
110         """Return a list of nodes representing the path from
111         the tree's root to the node argument.  If the node does
112         not exist in the tree for some reason, the last element
113         on the path will be None but the path will indicate the
114         ancestor path of that node were it inserted.
115
116         >>> t = BinarySearchTree()
117         >>> t.insert(50)
118         >>> t.insert(75)
119         >>> t.insert(25)
120         >>> t.insert(12)
121         >>> t.insert(33)
122         >>> t.insert(4)
123         >>> t.insert(88)
124         >>> t
125         50
126         ├──25
127         │  ├──12
128         │  │  └──4
129         │  └──33
130         └──75
131            └──88
132
133         >>> n = t[4]
134         >>> for x in t.parent_path(n):
135         ...     print(x.value)
136         50
137         25
138         12
139         4
140
141         >>> del t[4]
142         >>> for x in t.parent_path(n):
143         ...     if x is not None:
144         ...         print(x.value)
145         ...     else:
146         ...         print(x)
147         50
148         25
149         12
150         None
151
152         """
153         return self._parent_path(self.root, node)
154
155     def __delitem__(self, value: Any) -> bool:
156         """
157         Delete an item from the tree and preserve the BST property.
158
159         >>> t = BinarySearchTree()
160         >>> t.insert(50)
161         >>> t.insert(75)
162         >>> t.insert(25)
163         >>> t.insert(66)
164         >>> t.insert(22)
165         >>> t.insert(13)
166         >>> t.insert(85)
167         >>> t
168         50
169         ├──25
170         │  └──22
171         │     └──13
172         └──75
173            ├──66
174            └──85
175
176         >>> for value in t.iterate_inorder():
177         ...     print(value)
178         13
179         22
180         25
181         50
182         66
183         75
184         85
185
186         >>> del t[22]  # Note: bool result is discarded
187
188         >>> for value in t.iterate_inorder():
189         ...     print(value)
190         13
191         25
192         50
193         66
194         75
195         85
196
197         >>> t.__delitem__(13)
198         True
199         >>> for value in t.iterate_inorder():
200         ...     print(value)
201         25
202         50
203         66
204         75
205         85
206
207         >>> t.__delitem__(75)
208         True
209         >>> for value in t.iterate_inorder():
210         ...     print(value)
211         25
212         50
213         66
214         85
215         >>> t
216         50
217         ├──25
218         └──85
219            └──66
220
221         >>> t.__delitem__(99)
222         False
223
224         """
225         if self.root is not None:
226             ret = self._delete(value, None, self.root)
227             if ret:
228                 self.count -= 1
229                 if self.count == 0:
230                     self.root = None
231             return ret
232         return False
233
234     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
235         """Delete helper"""
236         if node.value == value:
237             # Deleting a leaf node
238             if node.left is None and node.right is None:
239                 if parent is not None:
240                     if parent.left == node:
241                         parent.left = None
242                     else:
243                         assert parent.right == node
244                         parent.right = None
245                 return True
246
247             # Node only has a right.
248             elif node.left is None:
249                 assert node.right is not None
250                 if parent is not None:
251                     if parent.left == node:
252                         parent.left = node.right
253                     else:
254                         assert parent.right == node
255                         parent.right = node.right
256                 return True
257
258             # Node only has a left.
259             elif node.right is None:
260                 assert node.left is not None
261                 if parent is not None:
262                     if parent.left == node:
263                         parent.left = node.left
264                     else:
265                         assert parent.right == node
266                         parent.right = node.left
267                 return True
268
269             # Node has both a left and right.
270             else:
271                 assert node.left is not None and node.right is not None
272                 descendent = node.right
273                 while descendent.left is not None:
274                     descendent = descendent.left
275                 node.value = descendent.value
276                 return self._delete(node.value, node, node.right)
277         elif value < node.value and node.left is not None:
278             return self._delete(value, node, node.left)
279         elif value > node.value and node.right is not None:
280             return self._delete(value, node, node.right)
281         return False
282
283     def __len__(self):
284         """
285         Returns the count of items in the tree.
286
287         >>> t = BinarySearchTree()
288         >>> len(t)
289         0
290         >>> t.insert(50)
291         >>> len(t)
292         1
293         >>> t.__delitem__(50)
294         True
295         >>> len(t)
296         0
297         >>> t.insert(75)
298         >>> t.insert(25)
299         >>> t.insert(66)
300         >>> t.insert(22)
301         >>> t.insert(13)
302         >>> t.insert(85)
303         >>> len(t)
304         6
305
306         """
307         return self.count
308
309     def __contains__(self, value: Any) -> bool:
310         """
311         Returns True if the item is in the tree; False otherwise.
312
313         """
314         return self.__getitem__(value) is not None
315
316     def _iterate_preorder(self, node: Node):
317         yield node.value
318         if node.left is not None:
319             yield from self._iterate_preorder(node.left)
320         if node.right is not None:
321             yield from self._iterate_preorder(node.right)
322
323     def _iterate_inorder(self, node: Node):
324         if node.left is not None:
325             yield from self._iterate_inorder(node.left)
326         yield node.value
327         if node.right is not None:
328             yield from self._iterate_inorder(node.right)
329
330     def _iterate_postorder(self, node: Node):
331         if node.left is not None:
332             yield from self._iterate_postorder(node.left)
333         if node.right is not None:
334             yield from self._iterate_postorder(node.right)
335         yield node.value
336
337     def iterate_preorder(self):
338         """
339         Yield the tree's items in a preorder traversal sequence.
340
341         >>> t = BinarySearchTree()
342         >>> t.insert(50)
343         >>> t.insert(75)
344         >>> t.insert(25)
345         >>> t.insert(66)
346         >>> t.insert(22)
347         >>> t.insert(13)
348
349         >>> for value in t.iterate_preorder():
350         ...     print(value)
351         50
352         25
353         22
354         13
355         75
356         66
357
358         """
359         if self.root is not None:
360             yield from self._iterate_preorder(self.root)
361
362     def iterate_inorder(self):
363         """
364         Yield the tree's items in a preorder traversal sequence.
365
366         >>> t = BinarySearchTree()
367         >>> t.insert(50)
368         >>> t.insert(75)
369         >>> t.insert(25)
370         >>> t.insert(66)
371         >>> t.insert(22)
372         >>> t.insert(13)
373         >>> t.insert(24)
374         >>> t
375         50
376         ├──25
377         │  └──22
378         │     ├──13
379         │     └──24
380         └──75
381            └──66
382
383         >>> for value in t.iterate_inorder():
384         ...     print(value)
385         13
386         22
387         24
388         25
389         50
390         66
391         75
392
393         """
394         if self.root is not None:
395             yield from self._iterate_inorder(self.root)
396
397     def iterate_postorder(self):
398         """
399         Yield the tree's items in a preorder traversal sequence.
400
401         >>> t = BinarySearchTree()
402         >>> t.insert(50)
403         >>> t.insert(75)
404         >>> t.insert(25)
405         >>> t.insert(66)
406         >>> t.insert(22)
407         >>> t.insert(13)
408
409         >>> for value in t.iterate_postorder():
410         ...     print(value)
411         13
412         22
413         25
414         66
415         75
416         50
417
418         """
419         if self.root is not None:
420             yield from self._iterate_postorder(self.root)
421
422     def _iterate_leaves(self, node: Node):
423         if node.left is not None:
424             yield from self._iterate_leaves(node.left)
425         if node.right is not None:
426             yield from self._iterate_leaves(node.right)
427         if node.left is None and node.right is None:
428             yield node.value
429
430     def iterate_leaves(self):
431         """
432         Iterate only the leaf nodes in the tree.
433
434         >>> t = BinarySearchTree()
435         >>> t.insert(50)
436         >>> t.insert(75)
437         >>> t.insert(25)
438         >>> t.insert(66)
439         >>> t.insert(22)
440         >>> t.insert(13)
441
442         >>> for value in t.iterate_leaves():
443         ...     print(value)
444         13
445         66
446
447         """
448         if self.root is not None:
449             yield from self._iterate_leaves(self.root)
450
451     def _iterate_by_depth(self, node: Node, depth: int):
452         if depth == 0:
453             yield node.value
454         else:
455             assert depth > 0
456             if node.left is not None:
457                 yield from self._iterate_by_depth(node.left, depth - 1)
458             if node.right is not None:
459                 yield from self._iterate_by_depth(node.right, depth - 1)
460
461     def iterate_nodes_by_depth(self, depth: int) -> Generator[Node, None, None]:
462         """
463         Iterate only the leaf nodes in the tree.
464
465         >>> t = BinarySearchTree()
466         >>> t.insert(50)
467         >>> t.insert(75)
468         >>> t.insert(25)
469         >>> t.insert(66)
470         >>> t.insert(22)
471         >>> t.insert(13)
472
473         >>> for value in t.iterate_nodes_by_depth(2):
474         ...     print(value)
475         22
476         66
477
478         >>> for value in t.iterate_nodes_by_depth(3):
479         ...     print(value)
480         13
481
482         """
483         if self.root is not None:
484             yield from self._iterate_by_depth(self.root, depth)
485
486     def get_next_node(self, node: Node) -> Node:
487         """
488         Given a tree node, get the next greater node in the tree.
489
490         >>> t = BinarySearchTree()
491         >>> t.insert(50)
492         >>> t.insert(75)
493         >>> t.insert(25)
494         >>> t.insert(66)
495         >>> t.insert(22)
496         >>> t.insert(13)
497         >>> t.insert(23)
498         >>> t
499         50
500         ├──25
501         │  └──22
502         │     ├──13
503         │     └──23
504         └──75
505            └──66
506
507         >>> n = t[23]
508         >>> t.get_next_node(n).value
509         25
510
511         >>> n = t[50]
512         >>> t.get_next_node(n).value
513         66
514
515         """
516         if node.right is not None:
517             x = node.right
518             while x.left is not None:
519                 x = x.left
520             return x
521
522         path = self.parent_path(node)
523         assert path[-1] is not None
524         assert path[-1] == node
525         path = path[:-1]
526         path.reverse()
527         for ancestor in path:
528             assert ancestor is not None
529             if node != ancestor.right:
530                 return ancestor
531             node = ancestor
532         raise Exception()
533
534     def _depth(self, node: Node, sofar: int) -> int:
535         depth_left = sofar + 1
536         depth_right = sofar + 1
537         if node.left is not None:
538             depth_left = self._depth(node.left, sofar + 1)
539         if node.right is not None:
540             depth_right = self._depth(node.right, sofar + 1)
541         return max(depth_left, depth_right)
542
543     def depth(self):
544         """
545         Returns the max height (depth) of the tree in plies (edge distance
546         from root).
547
548         >>> t = BinarySearchTree()
549         >>> t.depth()
550         0
551
552         >>> t.insert(50)
553         >>> t.depth()
554         1
555
556         >>> t.insert(65)
557         >>> t.depth()
558         2
559
560         >>> t.insert(33)
561         >>> t.depth()
562         2
563
564         >>> t.insert(2)
565         >>> t.insert(1)
566         >>> t.depth()
567         4
568
569         """
570         if self.root is None:
571             return 0
572         return self._depth(self.root, 0)
573
574     def height(self):
575         return self.depth()
576
577     def repr_traverse(
578         self, padding: str, pointer: str, node: Optional[Node], has_right_sibling: bool
579     ) -> str:
580         if node is not None:
581             viz = f'\n{padding}{pointer}{node.value}'
582             if has_right_sibling:
583                 padding += "│  "
584             else:
585                 padding += '   '
586
587             pointer_right = "└──"
588             if node.right is not None:
589                 pointer_left = "├──"
590             else:
591                 pointer_left = "└──"
592
593             viz += self.repr_traverse(
594                 padding, pointer_left, node.left, node.right is not None
595             )
596             viz += self.repr_traverse(padding, pointer_right, node.right, False)
597             return viz
598         return ""
599
600     def __repr__(self):
601         """
602         Draw the tree in ASCII.
603
604         >>> t = BinarySearchTree()
605         >>> t.insert(50)
606         >>> t.insert(25)
607         >>> t.insert(75)
608         >>> t.insert(12)
609         >>> t.insert(33)
610         >>> t.insert(88)
611         >>> t.insert(55)
612         >>> t
613         50
614         ├──25
615         │  ├──12
616         │  └──33
617         └──75
618            ├──55
619            └──88
620         """
621         if self.root is None:
622             return ""
623
624         ret = f'{self.root.value}'
625         pointer_right = "└──"
626         if self.root.right is None:
627             pointer_left = "└──"
628         else:
629             pointer_left = "├──"
630
631         ret += self.repr_traverse(
632             '', pointer_left, self.root.left, self.root.left is not None
633         )
634         ret += self.repr_traverse('', pointer_right, self.root.right, False)
635         return ret
636
637
638 if __name__ == '__main__':
639     import doctest
640
641     doctest.testmod()