#!/usr/bin/env python3
-# © Copyright 2021-2022, Scott Gasch
+# © Copyright 2021-2023, Scott Gasch
"""A binary search tree implementation."""
-from typing import Any, Generator, List, Optional
+from typing import Generator, List, Optional
+from pyutils.typez.typing import Comparable
-class Node(object):
- def __init__(self, value: Any) -> None:
- """
- A BST node. Note that value can be anything as long as it
- is comparable. Check out :meth:`functools.total_ordering`
- (https://docs.python.org/3/library/functools.html#functools.total_ordering)
+
+class Node:
+ def __init__(self, value: Comparable) -> None:
+ """A BST node. Just a left and right reference along with a
+ value. Note that value can be anything as long as it
+ is :class:`Comparable` with other instances of itself.
Args:
- value: a reference to the value of the node.
+ value: a reference to the value of the node. Must be
+ :class:`Comparable` to other values.
+
"""
self.left: Optional[Node] = None
self.right: Optional[Node] = None
- self.value = value
+ self.value: Comparable = value
class BinarySearchTree(object):
return self.root
- def insert(self, value: Any) -> None:
+ def _on_insert(self, parent: Optional[Node], new: Node) -> None:
+ """This is called immediately _after_ a new node is inserted."""
+ pass
+
+ def insert(self, value: Comparable) -> None:
"""
- Insert something into the tree.
+ Insert something into the tree in :math:`O(log_2 n)` time.
Args:
value: the value to be inserted.
if self.root is None:
self.root = Node(value)
self.count = 1
+ self._on_insert(None, self.root)
else:
self._insert(value, self.root)
- def _insert(self, value: Any, node: Node):
+ def _insert(self, value: Comparable, node: Node):
"""Insertion helper"""
if value < node.value:
if node.left is not None:
else:
node.left = Node(value)
self.count += 1
+ self._on_insert(node, node.left)
else:
if node.right is not None:
self._insert(value, node.right)
else:
node.right = Node(value)
self.count += 1
+ self._on_insert(node, node.right)
- def __getitem__(self, value: Any) -> Optional[Node]:
+ def __getitem__(self, value: Comparable) -> Optional[Node]:
"""
- Find an item in the tree and return its Node. Returns
- None if the item is not in the tree.
+ Find an item in the tree and return its Node in
+ :math:`O(log_2 n)` time. Returns None if the item is not in
+ the tree.
>>> t = BinarySearchTree()
>>> t[99]
"""
if self.root is not None:
- return self._find(value, self.root)
+ return self._find_exact(value, self.root)
return None
- def _find(self, value: Any, node: Node) -> Optional[Node]:
- """Find helper"""
- if value == node.value:
+ def _find_exact(self, target: Comparable, node: Node) -> Optional[Node]:
+ """Recursively traverse the tree looking for a node with the
+ target value. Return that node if it exists, otherwise return
+ None."""
+
+ if target == node.value:
return node
- elif value < node.value and node.left is not None:
- return self._find(value, node.left)
- elif value > node.value and node.right is not None:
- return self._find(value, node.right)
+ elif target < node.value and node.left is not None:
+ return self._find_exact(target, node.left)
+ elif target > node.value and node.right is not None:
+ return self._find_exact(target, node.right)
return None
+ def _find_lowest_node_less_than_or_equal_to(
+ self, target: Comparable, node: Optional[Node]
+ ) -> Optional[Node]:
+ """Find helper that returns the lowest node that is less
+ than or equal to the target value. Returns None if target is
+ lower than the lowest node in the tree.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(66)
+ >>> t.insert(22)
+ >>> t.insert(13)
+ >>> t.insert(85)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ └──13
+ └──75
+ ├──66
+ └──85
+
+ >>> t._find_lowest_node_less_than_or_equal_to(48, t.root).value
+ 25
+ >>> t._find_lowest_node_less_than_or_equal_to(55, t.root).value
+ 50
+ >>> t._find_lowest_node_less_than_or_equal_to(100, t.root).value
+ 85
+ >>> t._find_lowest_node_less_than_or_equal_to(24, t.root).value
+ 22
+ >>> t._find_lowest_node_less_than_or_equal_to(20, t.root).value
+ 13
+ >>> t._find_lowest_node_less_than_or_equal_to(72, t.root).value
+ 66
+ >>> t._find_lowest_node_less_than_or_equal_to(78, t.root).value
+ 75
+ >>> t._find_lowest_node_less_than_or_equal_to(12, t.root) is None
+ True
+
+ """
+
+ if not node:
+ return None
+
+ if target == node.value:
+ return node
+
+ elif target > node.value:
+ if below := self._find_lowest_node_less_than_or_equal_to(
+ target, node.right
+ ):
+ return below
+ else:
+ return node
+
+ else:
+ return self._find_lowest_node_less_than_or_equal_to(target, node.left)
+
+ def _find_lowest_node_greater_than_or_equal_to(
+ self, target: Comparable, node: Optional[Node]
+ ) -> Optional[Node]:
+ """Find helper that returns the lowest node that is greater
+ than or equal to the target value. Returns None if target is
+ higher than the greatest node in the tree.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(66)
+ >>> t.insert(22)
+ >>> t.insert(13)
+ >>> t.insert(85)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ └──13
+ └──75
+ ├──66
+ └──85
+
+ >>> t._find_lowest_node_greater_than_or_equal_to(48, t.root).value
+ 50
+ >>> t._find_lowest_node_greater_than_or_equal_to(55, t.root).value
+ 66
+ >>> t._find_lowest_node_greater_than_or_equal_to(1, t.root).value
+ 13
+ >>> t._find_lowest_node_greater_than_or_equal_to(24, t.root).value
+ 25
+ >>> t._find_lowest_node_greater_than_or_equal_to(20, t.root).value
+ 22
+ >>> t._find_lowest_node_greater_than_or_equal_to(72, t.root).value
+ 75
+ >>> t._find_lowest_node_greater_than_or_equal_to(78, t.root).value
+ 85
+ >>> t._find_lowest_node_greater_than_or_equal_to(95, t.root) is None
+ True
+
+ """
+
+ if not node:
+ return None
+
+ if target == node.value:
+ return node
+
+ elif target > node.value:
+ return self._find_lowest_node_greater_than_or_equal_to(target, node.right)
+
+ # If target < this node's value, either this node is the
+ # answer or the answer is in this node's left subtree.
+ else:
+ if below := self._find_lowest_node_greater_than_or_equal_to(
+ target, node.left
+ ):
+ return below
+ else:
+ return node
+
def _parent_path(
self, current: Optional[Node], target: Node
) -> List[Optional[Node]]:
return ret
def parent_path(self, node: Node) -> List[Optional[Node]]:
- """Get a node's parent path.
+ """Get a node's parent path in :math:`O(log_2 n)` time.
Args:
- node: the node to check
+ node: the node whose parent path should be returned.
Returns:
a list of nodes representing the path from
- the tree's root to the node.
+ the tree's root to the given node.
.. note::
"""
return self._parent_path(self.root, node)
- def __delitem__(self, value: Any) -> bool:
+ def __delitem__(self, value: Comparable) -> bool:
"""
- Delete an item from the tree and preserve the BST property.
+ Delete an item from the tree and preserve the BST property in
+ :math:`O(log_2 n) time`.
Args:
value: the value of the node to be deleted.
└──85
└──66
+ >>> t.__delitem__(85)
+ True
+
>>> t.__delitem__(99)
False
return ret
return False
- def _delete(self, value: Any, parent: Optional[Node], node: Node) -> bool:
+ def _on_delete(self, parent: Optional[Node], deleted: Node) -> None:
+ """This is called just after deleted was deleted from the tree"""
+ pass
+
+ def _delete(self, value: Comparable, parent: Optional[Node], node: Node) -> bool:
"""Delete helper"""
if node.value == value:
+
# Deleting a leaf node
if node.left is None and node.right is None:
if parent is not None:
else:
assert parent.right == node
parent.right = None
+ self._on_delete(parent, node)
return True
# Node only has a right.
else:
assert parent.right == node
parent.right = node.right
+ self._on_delete(parent, node)
return True
# Node only has a left.
else:
assert parent.right == node
parent.right = node.left
+ self._on_delete(parent, node)
return True
- # Node has both a left and right.
+ # Node has both a left and right; get the successor node
+ # to this one and put it here (deleting the successor's
+ # old node). Because these operations are happening only
+ # in the subtree underneath of node, I'm still calling
+ # this delete an O(log_2 n) operation in the docs.
else:
assert node.left is not None and node.right is not None
- descendent = node.right
- while descendent.left is not None:
- descendent = descendent.left
- node.value = descendent.value
+ successor = self.get_next_node(node)
+ assert successor is not None
+ node.value = successor.value
return self._delete(node.value, node, node.right)
+
elif value < node.value and node.left is not None:
return self._delete(value, node, node.left)
elif value > node.value and node.right is not None:
def __len__(self):
"""
Returns:
- The count of items in the tree.
+ The count of items in the tree in :math:`O(1)` time.
>>> t = BinarySearchTree()
>>> len(t)
"""
return self.count
- def __contains__(self, value: Any) -> bool:
+ def __contains__(self, value: Comparable) -> bool:
"""
Returns:
True if the item is in the tree; False otherwise.
def iterate_leaves(self):
"""
Returns:
- A Gemerator that yielde only the leaf nodes in the
+ A Generator that yields only the leaf nodes in the
tree.
>>> t = BinarySearchTree()
if self.root is not None:
yield from self._iterate_by_depth(self.root, depth)
- def get_next_node(self, node: Node) -> Node:
+ def get_next_node(self, node: Node) -> Optional[Node]:
"""
Args:
node: the node whose next greater successor is desired
Returns:
Given a tree node, returns the next greater node in the tree.
+ If the given node is the greatest node in the tree, returns None.
>>> t = BinarySearchTree()
>>> t.insert(50)
>>> t.get_next_node(n).value
66
+ >>> n = t[75]
+ >>> t.get_next_node(n) is None
+ True
+
"""
if node.right is not None:
x = node.right
if node != ancestor.right:
return ancestor
node = ancestor
- raise Exception()
+ return None
+
+ def get_nodes_in_range_inclusive(
+ self, lower: Comparable, upper: Comparable
+ ) -> Generator[Node, None, None]:
+ """
+ Args:
+ lower: the lower bound of the desired range.
+ upper: the upper bound of the desired range.
+
+ Returns:
+ Generates a sequence of nodes in the desired range.
+
+ >>> t = BinarySearchTree()
+ >>> t.insert(50)
+ >>> t.insert(75)
+ >>> t.insert(25)
+ >>> t.insert(66)
+ >>> t.insert(22)
+ >>> t.insert(13)
+ >>> t.insert(23)
+ >>> t
+ 50
+ ├──25
+ │ └──22
+ │ ├──13
+ │ └──23
+ └──75
+ └──66
+
+ >>> for node in t.get_nodes_in_range_inclusive(21, 74):
+ ... print(node.value)
+ 22
+ 23
+ 25
+ 50
+ 66
+ """
+ node: Optional[Node] = self._find_lowest_node_greater_than_or_equal_to(
+ lower, self.root
+ )
+ while node:
+ if lower <= node.value <= upper:
+ yield node
+ node = self.get_next_node(node)
def _depth(self, node: Node, sofar: int) -> int:
depth_left = sofar + 1
"""
Returns:
The max height (depth) of the tree in plies (edge distance
- from root).
+ from root) in :math:`O(log_2 n)` time.
>>> t = BinarySearchTree()
>>> t.depth()
has_right_sibling: bool,
) -> str:
if node is not None:
- viz = f'\n{padding}{pointer}{node.value}'
+ viz = f"\n{padding}{pointer}{node.value}"
if has_right_sibling:
padding += "│ "
else:
- padding += ' '
+ padding += " "
pointer_right = "└──"
if node.right is not None:
if self.root is None:
return ""
- ret = f'{self.root.value}'
+ ret = f"{self.root.value}"
pointer_right = "└──"
if self.root.right is None:
pointer_left = "└──"
pointer_left = "├──"
ret += self.repr_traverse(
- '', pointer_left, self.root.left, self.root.left is not None
+ "", pointer_left, self.root.left, self.root.left is not None
)
- ret += self.repr_traverse('', pointer_right, self.root.right, False)
+ ret += self.repr_traverse("", pointer_right, self.root.right, False)
return ret
-if __name__ == '__main__':
+if __name__ == "__main__":
import doctest
doctest.testmod()