3 from typing import Any, Optional, List
7 def __init__(self, value: Any) -> None:
9 Note: value can be anything as long as it is comparable.
10 Check out @functools.total_ordering.
17 class BinarySearchTree(object):
23 def get_root(self) -> Optional[Node]:
26 def insert(self, value: Any):
28 Insert something into the tree.
30 >>> t = BinarySearchTree()
37 >>> t.get_root().value
42 self.root = Node(value)
45 self._insert(value, self.root)
47 def _insert(self, value: Any, node: Node):
48 """Insertion helper"""
49 if value < node.value:
50 if node.left is not None:
51 self._insert(value, node.left)
53 node.left = Node(value)
56 if node.right is not None:
57 self._insert(value, node.right)
59 node.right = Node(value)
62 def __getitem__(self, value: Any) -> Optional[Node]:
64 Find an item in the tree and return its Node. Returns
65 None if the item is not in the tree.
67 >>> t = BinarySearchTree()
79 if self.root is not None:
80 return self._find(value, self.root)
83 def _find(self, value: Any, node: Node) -> Optional[Node]:
85 if value == node.value:
87 elif (value < node.value and node.left is not None):
88 return self._find(value, node.left)
89 elif (value > node.value and node.right is not None):
90 return self._find(value, node.right)
93 def _parent_path(self, current: Node, target: Node):
97 if target.value == current.value:
99 elif target.value < current.value:
100 ret.extend(self._parent_path(current.left, target))
103 assert target.value > current.value
104 ret.extend(self._parent_path(current.right, target))
107 def parent_path(self, node: Node) -> Optional[List[Node]]:
108 """Return a list of nodes representing the path from
109 the tree's root to the node argument. If the node does
110 not exist in the tree for some reason, the last element
111 on the path will be None but the path will indicate the
112 ancestor path of that node were it inserted.
114 >>> t = BinarySearchTree()
132 >>> for x in t.parent_path(n):
140 >>> for x in t.parent_path(n):
141 ... if x is not None:
151 return self._parent_path(self.root, node)
153 def __delitem__(self, value: Any) -> bool:
155 Delete an item from the tree and preserve the BST property.
157 >>> t = BinarySearchTree()
174 >>> for value in t.iterate_inorder():
184 >>> del t[22] # Note: bool result is discarded
186 >>> for value in t.iterate_inorder():
195 >>> t.__delitem__(13)
197 >>> for value in t.iterate_inorder():
205 >>> t.__delitem__(75)
207 >>> for value in t.iterate_inorder():
219 >>> t.__delitem__(99)
223 if self.root is not None:
224 ret = self._delete(value, None, self.root)
232 def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
234 if node.value == value:
235 # Deleting a leaf node
236 if node.left is None and node.right is None:
237 if parent is not None:
238 if parent.left == node:
241 assert parent.right == node
245 # Node only has a right.
246 elif node.left is None:
247 assert node.right is not None
248 if parent is not None:
249 if parent.left == node:
250 parent.left = node.right
252 assert parent.right == node
253 parent.right = node.right
256 # Node only has a left.
257 elif node.right is None:
258 assert node.left is not None
259 if parent is not None:
260 if parent.left == node:
261 parent.left = node.left
263 assert parent.right == node
264 parent.right = node.left
267 # Node has both a left and right.
269 assert node.left is not None and node.right is not None
270 descendent = node.right
271 while descendent.left is not None:
272 descendent = descendent.left
273 node.value = descendent.value
274 return self._delete(node.value, node, node.right)
275 elif value < node.value and node.left is not None:
276 return self._delete(value, node, node.left)
277 elif value > node.value and node.right is not None:
278 return self._delete(value, node, node.right)
283 Returns the count of items in the tree.
285 >>> t = BinarySearchTree()
291 >>> t.__delitem__(50)
307 def __contains__(self, value: Any) -> bool:
309 Returns True if the item is in the tree; False otherwise.
312 return self.__getitem__(value) is not None
314 def _iterate_preorder(self, node: Node):
316 if node.left is not None:
317 yield from self._iterate_preorder(node.left)
318 if node.right is not None:
319 yield from self._iterate_preorder(node.right)
321 def _iterate_inorder(self, node: Node):
322 if node.left is not None:
323 yield from self._iterate_inorder(node.left)
325 if node.right is not None:
326 yield from self._iterate_inorder(node.right)
328 def _iterate_postorder(self, node: Node):
329 if node.left is not None:
330 yield from self._iterate_postorder(node.left)
331 if node.right is not None:
332 yield from self._iterate_postorder(node.right)
335 def iterate_preorder(self):
337 Yield the tree's items in a preorder traversal sequence.
339 >>> t = BinarySearchTree()
347 >>> for value in t.iterate_preorder():
357 if self.root is not None:
358 yield from self._iterate_preorder(self.root)
360 def iterate_inorder(self):
362 Yield the tree's items in an inorder traversal sequence.
364 >>> t = BinarySearchTree()
381 >>> for value in t.iterate_inorder():
392 if self.root is not None:
393 yield from self._iterate_inorder(self.root)
395 def iterate_postorder(self):
397 Yield the tree's items in a postorder traversal sequence.
399 >>> t = BinarySearchTree()
407 >>> for value in t.iterate_postorder():
417 if self.root is not None:
418 yield from self._iterate_postorder(self.root)
420 def _iterate_leaves(self, node: Node):
421 if node.left is not None:
422 yield from self._iterate_leaves(node.left)
423 if node.right is not None:
424 yield from self._iterate_leaves(node.right)
425 if node.left is None and node.right is None:
428 def iterate_leaves(self):
430 Iterate only the leaf nodes in the tree.
432 >>> t = BinarySearchTree()
440 >>> for value in t.iterate_leaves():
446 if self.root is not None:
447 yield from self._iterate_leaves(self.root)
449 def _iterate_by_depth(self, node: Node, depth: int):
454 if node.left is not None:
455 yield from self._iterate_by_depth(node.left, depth - 1)
456 if node.right is not None:
457 yield from self._iterate_by_depth(node.right, depth - 1)
459 def iterate_nodes_by_depth(self, depth: int):
461 Iterate only the nodes at the given depth in the tree.
463 >>> t = BinarySearchTree()
471 >>> for value in t.iterate_nodes_by_depth(2):
476 >>> for value in t.iterate_nodes_by_depth(3):
481 if self.root is not None:
482 yield from self._iterate_by_depth(self.root, depth)
484 def get_next_node(self, node: Node) -> Node:
486 Given a tree node, get the next greater node in the tree.
488 >>> t = BinarySearchTree()
506 >>> t.get_next_node(n).value
510 >>> t.get_next_node(n).value
514 if node.right is not None:
516 while x.left is not None:
520 path = self.parent_path(node)
521 assert path[-1] == node
524 for ancestor in path:
525 if node != ancestor.right:
529 def _depth(self, node: Node, sofar: int) -> int:
530 depth_left = sofar + 1
531 depth_right = sofar + 1
532 if node.left is not None:
533 depth_left = self._depth(node.left, sofar + 1)
534 if node.right is not None:
535 depth_right = self._depth(node.right, sofar + 1)
536 return max(depth_left, depth_right)
540 Returns the max height (depth) of the tree in plies (edge distance
543 >>> t = BinarySearchTree()
565 if self.root is None:
567 return self._depth(self.root, 0)
572 def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
574 viz = f'\n{padding}{pointer}{node.value}'
575 if has_right_sibling:
580 pointer_right = "└──"
581 if node.right is not None:
586 viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
587 viz += self.repr_traverse(padding, pointer_right, node.right, False)
593 Draw the tree in ASCII.
595 >>> t = BinarySearchTree()
612 if self.root is None:
615 ret = f'{self.root.value}'
616 pointer_right = "└──"
617 if self.root.right is None:
622 ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
623 ret += self.repr_traverse('', pointer_right, self.root.right, False)
627 if __name__ == '__main__':