Some binary tree methods to support the unscramble progam's sparsefile
[python_utils.git] / collect / bst.py
1 #!/usr/bin/env python3
2
3 from typing import Any, Optional, List
4
5
6 class Node(object):
7     def __init__(self, value: Any) -> None:
8         self.left = None
9         self.right = None
10         self.value = value
11
12
13 class BinarySearchTree(object):
14     def __init__(self):
15         self.root = None
16         self.count = 0
17         self.traverse = None
18
19     def get_root(self) -> Optional[Node]:
20         return self.root
21
22     def insert(self, value: Any):
23         """
24         Insert something into the tree.
25
26         >>> t = BinarySearchTree()
27         >>> t.insert(10)
28         >>> t.insert(20)
29         >>> t.insert(5)
30         >>> len(t)
31         3
32
33         >>> t.get_root().value
34         10
35
36         """
37         if self.root is None:
38             self.root = Node(value)
39             self.count = 1
40         else:
41             self._insert(value, self.root)
42
43     def _insert(self, value: Any, node: Node):
44         """Insertion helper"""
45         if value < node.value:
46             if node.left is not None:
47                 self._insert(value, node.left)
48             else:
49                 node.left = Node(value)
50                 self.count += 1
51         else:
52             if node.right is not None:
53                 self._insert(value, node.right)
54             else:
55                 node.right = Node(value)
56                 self.count += 1
57
58     def __getitem__(self, value: Any) -> Optional[Node]:
59         """
60         Find an item in the tree and return its Node.  Returns
61         None if the item is not in the tree.
62
63         >>> t = BinarySearchTree()
64         >>> t[99]
65
66         >>> t.insert(10)
67         >>> t.insert(20)
68         >>> t.insert(5)
69         >>> t[10].value
70         10
71
72         >>> t[99]
73
74         """
75         if self.root is not None:
76             return self._find(value, self.root)
77         return None
78
79     def _find(self, value: Any, node: Node) -> Optional[Node]:
80         """Find helper"""
81         if value == node.value:
82             return node
83         elif (value < node.value and node.left is not None):
84             return self._find(value, node.left)
85         else:
86             assert value > node.value
87             if node.right is not None:
88                 return self._find(value, node.right)
89         return None
90
91     def _parent_path(self, current: Node, target: Node):
92         if current is None:
93             return [None]
94         ret = [current]
95         if target.value == current.value:
96             return ret
97         elif target.value < current.value:
98             ret.extend(self._parent_path(current.left, target))
99             return ret
100         else:
101             assert target.value > current.value
102             ret.extend(self._parent_path(current.right, target))
103             return ret
104
105     def parent_path(self, node: Node) -> Optional[List[Node]]:
106         """Return a list of nodes representing the path from
107         the tree's root to the node argument.
108
109         >>> t = BinarySearchTree()
110         >>> t.insert(50)
111         >>> t.insert(75)
112         >>> t.insert(25)
113         >>> t.insert(12)
114         >>> t.insert(33)
115         >>> t.insert(4)
116         >>> t.insert(88)
117         >>> t
118         50
119         ├──25
120         │  ├──12
121         │  │  └──4
122         │  └──33
123         └──75
124            └──88
125
126         >>> n = t[4]
127         >>> for x in t.parent_path(n):
128         ...     print(x.value)
129         50
130         25
131         12
132         4
133
134         """
135         return self._parent_path(self.root, node)
136
137     def __delitem__(self, value: Any) -> bool:
138         """
139         Delete an item from the tree and preserve the BST property.
140
141                             50
142                            /  \
143                          25    75
144                         /     /  \
145                       22    66    85
146                      /
147                    13
148
149
150         >>> t = BinarySearchTree()
151         >>> t.insert(50)
152         >>> t.insert(75)
153         >>> t.insert(25)
154         >>> t.insert(66)
155         >>> t.insert(22)
156         >>> t.insert(13)
157         >>> t.insert(85)
158
159         >>> for value in t.iterate_inorder():
160         ...     print(value)
161         13
162         22
163         25
164         50
165         66
166         75
167         85
168
169         >>> del t[22]  # Note: bool result is discarded
170
171         >>> for value in t.iterate_inorder():
172         ...     print(value)
173         13
174         25
175         50
176         66
177         75
178         85
179
180         >>> t.__delitem__(13)
181         True
182         >>> for value in t.iterate_inorder():
183         ...     print(value)
184         25
185         50
186         66
187         75
188         85
189
190         >>> t.__delitem__(75)
191         True
192         >>> for value in t.iterate_inorder():
193         ...     print(value)
194         25
195         50
196         66
197         85
198
199         >>> t.__delitem__(99)
200         False
201
202         """
203         if self.root is not None:
204             ret = self._delete(value, None, self.root)
205             if ret:
206                 self.count -= 1
207                 if self.count == 0:
208                     self.root = None
209             return ret
210         return False
211
212     def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
213         """Delete helper"""
214         if node.value == value:
215             # Deleting a leaf node
216             if node.left is None and node.right is None:
217                 if parent is not None:
218                     if parent.left == node:
219                         parent.left = None
220                     else:
221                         assert parent.right == node
222                         parent.right = None
223                 return True
224
225             # Node only has a right.
226             elif node.left is None:
227                 assert node.right is not None
228                 if parent is not None:
229                     if parent.left == node:
230                         parent.left = node.right
231                     else:
232                         assert parent.right == node
233                         parent.right = node.right
234                 return True
235
236             # Node only has a left.
237             elif node.right is None:
238                 assert node.left is not None
239                 if parent is not None:
240                     if parent.left == node:
241                         parent.left = node.left
242                     else:
243                         assert parent.right == node
244                         parent.right = node.left
245                 return True
246
247             # Node has both a left and right.
248             else:
249                 assert node.left is not None and node.right is not None
250                 descendent = node.right
251                 while descendent.left is not None:
252                     descendent = descendent.left
253                 node.value = descendent.value
254                 return self._delete(node.value, node, node.right)
255         elif value < node.value and node.left is not None:
256             return self._delete(value, node, node.left)
257         elif value > node.value and node.right is not None:
258             return self._delete(value, node, node.right)
259         return False
260
261     def __len__(self):
262         """
263         Returns the count of items in the tree.
264
265         >>> t = BinarySearchTree()
266         >>> len(t)
267         0
268         >>> t.insert(50)
269         >>> len(t)
270         1
271         >>> t.__delitem__(50)
272         True
273         >>> len(t)
274         0
275         >>> t.insert(75)
276         >>> t.insert(25)
277         >>> t.insert(66)
278         >>> t.insert(22)
279         >>> t.insert(13)
280         >>> t.insert(85)
281         >>> len(t)
282         6
283
284         """
285         return self.count
286
287     def __contains__(self, value: Any) -> bool:
288         """
289         Returns True if the item is in the tree; False otherwise.
290
291         """
292         return self.__getitem__(value) is not None
293
294     def _iterate_preorder(self, node: Node):
295         yield node.value
296         if node.left is not None:
297             yield from self._iterate_preorder(node.left)
298         if node.right is not None:
299             yield from self._iterate_preorder(node.right)
300
301     def _iterate_inorder(self, node: Node):
302         if node.left is not None:
303             yield from self._iterate_inorder(node.left)
304         yield node.value
305         if node.right is not None:
306             yield from self._iterate_inorder(node.right)
307
308     def _iterate_postorder(self, node: Node):
309         if node.left is not None:
310             yield from self._iterate_postorder(node.left)
311         if node.right is not None:
312             yield from self._iterate_postorder(node.right)
313         yield node.value
314
315     def iterate_preorder(self):
316         """
317         Yield the tree's items in a preorder traversal sequence.
318
319         >>> t = BinarySearchTree()
320         >>> t.insert(50)
321         >>> t.insert(75)
322         >>> t.insert(25)
323         >>> t.insert(66)
324         >>> t.insert(22)
325         >>> t.insert(13)
326
327         >>> for value in t.iterate_preorder():
328         ...     print(value)
329         50
330         25
331         22
332         13
333         75
334         66
335
336         """
337         if self.root is not None:
338             yield from self._iterate_preorder(self.root)
339
340     def iterate_inorder(self):
341         """
342         Yield the tree's items in a preorder traversal sequence.
343
344         >>> t = BinarySearchTree()
345         >>> t.insert(50)
346         >>> t.insert(75)
347         >>> t.insert(25)
348         >>> t.insert(66)
349         >>> t.insert(22)
350         >>> t.insert(13)
351         >>> t.insert(24)
352         >>> t
353         50
354         ├──25
355         │  └──22
356         │     ├──13
357         │     └──24
358         └──75
359            └──66
360
361         >>> for value in t.iterate_inorder():
362         ...     print(value)
363         13
364         22
365         24
366         25
367         50
368         66
369         75
370
371         """
372         if self.root is not None:
373             yield from self._iterate_inorder(self.root)
374
375     def iterate_postorder(self):
376         """
377         Yield the tree's items in a preorder traversal sequence.
378
379         >>> t = BinarySearchTree()
380         >>> t.insert(50)
381         >>> t.insert(75)
382         >>> t.insert(25)
383         >>> t.insert(66)
384         >>> t.insert(22)
385         >>> t.insert(13)
386
387         >>> for value in t.iterate_postorder():
388         ...     print(value)
389         13
390         22
391         25
392         66
393         75
394         50
395
396         """
397         if self.root is not None:
398             yield from self._iterate_postorder(self.root)
399
400     def _iterate_leaves(self, node: Node):
401         if node.left is not None:
402             yield from self._iterate_leaves(node.left)
403         if node.right is not None:
404             yield from self._iterate_leaves(node.right)
405         if node.left is None and node.right is None:
406             yield node.value
407
408     def iterate_leaves(self):
409         """
410         Iterate only the leaf nodes in the tree.
411
412         >>> t = BinarySearchTree()
413         >>> t.insert(50)
414         >>> t.insert(75)
415         >>> t.insert(25)
416         >>> t.insert(66)
417         >>> t.insert(22)
418         >>> t.insert(13)
419
420         >>> for value in t.iterate_leaves():
421         ...     print(value)
422         13
423         66
424
425         """
426         if self.root is not None:
427             yield from self._iterate_leaves(self.root)
428
429     def _iterate_by_depth(self, node: Node, depth: int):
430         if depth == 0:
431             yield node.value
432         else:
433             assert depth > 0
434             if node.left is not None:
435                 yield from self._iterate_by_depth(node.left, depth - 1)
436             if node.right is not None:
437                 yield from self._iterate_by_depth(node.right, depth - 1)
438
439     def iterate_nodes_by_depth(self, depth: int):
440         """
441         Iterate only the leaf nodes in the tree.
442
443         >>> t = BinarySearchTree()
444         >>> t.insert(50)
445         >>> t.insert(75)
446         >>> t.insert(25)
447         >>> t.insert(66)
448         >>> t.insert(22)
449         >>> t.insert(13)
450
451         >>> for value in t.iterate_nodes_by_depth(2):
452         ...     print(value)
453         22
454         66
455
456         >>> for value in t.iterate_nodes_by_depth(3):
457         ...     print(value)
458         13
459
460         """
461         if self.root is not None:
462             yield from self._iterate_by_depth(self.root, depth)
463
464     def get_next_node(self, node: Node) -> Node:
465         """
466         Given a tree node, get the next greater node in the tree.
467
468         >>> t = BinarySearchTree()
469         >>> t.insert(50)
470         >>> t.insert(75)
471         >>> t.insert(25)
472         >>> t.insert(66)
473         >>> t.insert(22)
474         >>> t.insert(13)
475         >>> t.insert(23)
476         >>> t
477         50
478         ├──25
479         │  └──22
480         │     ├──13
481         │     └──23
482         └──75
483            └──66
484
485         >>> n = t[23]
486         >>> t.get_next_node(n).value
487         25
488
489         >>> n = t[50]
490         >>> t.get_next_node(n).value
491         66
492
493         """
494         if node.right is not None:
495             x = node.right
496             while x.left is not None:
497                 x = x.left
498             return x
499
500         path = self.parent_path(node)
501         assert path[-1] == node
502         path = path[:-1]
503         path.reverse()
504         for ancestor in path:
505             if node != ancestor.right:
506                 return ancestor
507             node = ancestor
508
509     def _depth(self, node: Node, sofar: int) -> int:
510         depth_left = sofar + 1
511         depth_right = sofar + 1
512         if node.left is not None:
513             depth_left = self._depth(node.left, sofar + 1)
514         if node.right is not None:
515             depth_right = self._depth(node.right, sofar + 1)
516         return max(depth_left, depth_right)
517
518     def depth(self):
519         """
520         Returns the max height (depth) of the tree in plies (edge distance
521         from root).
522
523         >>> t = BinarySearchTree()
524         >>> t.depth()
525         0
526
527         >>> t.insert(50)
528         >>> t.depth()
529         1
530
531         >>> t.insert(65)
532         >>> t.depth()
533         2
534
535         >>> t.insert(33)
536         >>> t.depth()
537         2
538
539         >>> t.insert(2)
540         >>> t.insert(1)
541         >>> t.depth()
542         4
543
544         """
545         if self.root is None:
546             return 0
547         return self._depth(self.root, 0)
548
549     def height(self):
550         return self.depth()
551
552     def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
553         if node is not None:
554             self.viz += f'\n{padding}{pointer}{node.value}'
555             if has_right_sibling:
556                 padding += "│  "
557             else:
558                 padding += '   '
559
560             pointer_right = "└──"
561             if node.right is not None:
562                 pointer_left = "├──"
563             else:
564                 pointer_left = "└──"
565
566             self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
567             self.repr_traverse(padding, pointer_right, node.right, False)
568
569     def __repr__(self):
570         """
571         Draw the tree in ASCII.
572
573         >>> t = BinarySearchTree()
574         >>> t.insert(50)
575         >>> t.insert(25)
576         >>> t.insert(75)
577         >>> t.insert(12)
578         >>> t.insert(33)
579         >>> t.insert(88)
580         >>> t.insert(55)
581         >>> t
582         50
583         ├──25
584         │  ├──12
585         │  └──33
586         └──75
587            ├──55
588            └──88
589         """
590         if self.root is None:
591             return ""
592
593         self.viz = f'{self.root.value}'
594         pointer_right = "└──"
595         if self.root.right is None:
596             pointer_left = "└──"
597         else:
598             pointer_left = "├──"
599
600         self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
601         self.repr_traverse('', pointer_right, self.root.right, False)
602         return self.viz
603
604
605 if __name__ == '__main__':
606     import doctest
607     doctest.testmod()