changes
[python_utils.git] / collect / bst.py
index b4d25b34a627797660362a16d6430fdcd6d7eceb..72a3b7738b981878b9b08eaba67bca2b33314f4f 100644 (file)
@@ -5,6 +5,10 @@ from typing import Any, Optional, List
 
 class Node(object):
     def __init__(self, value: Any) -> None:
+        """
+        Note: value can be anything as long as it is comparable.
+        Check out @functools.total_ordering.
+        """
         self.left = None
         self.right = None
         self.value = value
@@ -82,10 +86,8 @@ class BinarySearchTree(object):
             return node
         elif (value < node.value and node.left is not None):
             return self._find(value, node.left)
-        else:
-            assert value > node.value
-            if node.right is not None:
-                return self._find(value, node.right)
+        elif (value > node.value and node.right is not None):
+            return self._find(value, node.right)
         return None
 
     def _parent_path(self, current: Node, target: Node):
@@ -104,7 +106,10 @@ class BinarySearchTree(object):
 
     def parent_path(self, node: Node) -> Optional[List[Node]]:
         """Return a list of nodes representing the path from
-        the tree's root to the node argument.
+        the tree's root to the node argument.  If the node does
+        not exist in the tree for some reason, the last element
+        on the path will be None but the path will indicate the
+        ancestor path of that node were it inserted.
 
         >>> t = BinarySearchTree()
         >>> t.insert(50)
@@ -131,6 +136,17 @@ class BinarySearchTree(object):
         12
         4
 
+        >>> del t[4]
+        >>> for x in t.parent_path(n):
+        ...     if x is not None:
+        ...         print(x.value)
+        ...     else:
+        ...         print(x)
+        50
+        25
+        12
+        None
+
         """
         return self._parent_path(self.root, node)
 
@@ -138,15 +154,6 @@ class BinarySearchTree(object):
         """
         Delete an item from the tree and preserve the BST property.
 
-                            50
-                           /  \
-                         25    75
-                        /     /  \
-                      22    66    85
-                     /
-                   13
-
-
         >>> t = BinarySearchTree()
         >>> t.insert(50)
         >>> t.insert(75)
@@ -155,6 +162,14 @@ class BinarySearchTree(object):
         >>> t.insert(22)
         >>> t.insert(13)
         >>> t.insert(85)
+        >>> t
+        50
+        ├──25
+        │  └──22
+        │     └──13
+        └──75
+           ├──66
+           └──85
 
         >>> for value in t.iterate_inorder():
         ...     print(value)
@@ -195,6 +210,11 @@ class BinarySearchTree(object):
         50
         66
         85
+        >>> t
+        50
+        ├──25
+        └──85
+           └──66
 
         >>> t.__delitem__(99)
         False
@@ -551,7 +571,7 @@ class BinarySearchTree(object):
 
     def repr_traverse(self, padding: str, pointer: str, node: Node, has_right_sibling: bool) -> str:
         if node is not None:
-            self.viz += f'\n{padding}{pointer}{node.value}'
+            viz = f'\n{padding}{pointer}{node.value}'
             if has_right_sibling:
                 padding += "│  "
             else:
@@ -563,8 +583,10 @@ class BinarySearchTree(object):
             else:
                 pointer_left = "└──"
 
-            self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
-            self.repr_traverse(padding, pointer_right, node.right, False)
+            viz += self.repr_traverse(padding, pointer_left, node.left, node.right is not None)
+            viz += self.repr_traverse(padding, pointer_right, node.right, False)
+            return viz
+        return ""
 
     def __repr__(self):
         """
@@ -590,16 +612,16 @@ class BinarySearchTree(object):
         if self.root is None:
             return ""
 
-        self.viz = f'{self.root.value}'
+        ret = f'{self.root.value}'
         pointer_right = "└──"
         if self.root.right is None:
             pointer_left = "└──"
         else:
             pointer_left = "├──"
 
-        self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
-        self.repr_traverse('', pointer_right, self.root.right, False)
-        return self.viz
+        ret += self.repr_traverse('', pointer_left, self.root.left, self.root.left is not None)
+        ret += self.repr_traverse('', pointer_right, self.root.right, False)
+        return ret
 
 
 if __name__ == '__main__':