changes
[python_utils.git] / collect / trie.py
index b9a5a1adbcd91cb96b593ac62b9eb5db7bdb6cc5..3e4c9172fbbf3b01202f6c9ccc5b2d4ff607fcc1 100644 (file)
@@ -15,6 +15,7 @@ class Trie(object):
         self.root = {}
         self.end = "~END~"
         self.length = 0
+        self.viz = ''
 
     def insert(self, item: Sequence[Any]):
         """
@@ -240,7 +241,37 @@ class Trie(object):
             return None
         return [x for x in node if x != self.end]
 
-    def repr_recursive(self, node, delimiter):
+    def repr_fancy(self, padding: str, pointer: str, parent: str, node: Any, has_sibling: bool):
+        if node is None:
+            return
+        if node is not self.root:
+            ret = f'\n{padding}{pointer}'
+            if has_sibling:
+                padding += '│  '
+            else:
+                padding += '   '
+        else:
+            ret = f'{pointer}'
+
+        child_count = 0
+        for child in node:
+            if child != self.end:
+                child_count += 1
+
+        for child in node:
+            if child != self.end:
+                if child_count > 1:
+                    pointer = "├──"
+                    has_sibling = True
+                else:
+                    pointer = "└──"
+                    has_sibling = False
+                pointer += f'{child}'
+                child_count -= 1
+                ret += self.repr_fancy(padding, pointer, node, node[child], has_sibling)
+        return ret
+
+    def repr_brief(self, node, delimiter):
         """
         A friendly string representation of the contents of the Trie.
 
@@ -249,10 +280,8 @@ class Trie(object):
         >>> t.insert([10, 0, 0, 2])
         >>> t.insert([10, 10, 10, 1])
         >>> t.insert([10, 10, 10, 2])
-        >>> t.repr_recursive(t.root, '.')
+        >>> t.repr_brief(t.root, '.')
         '10.[0.0.[1, 2], 10.10.[1, 2]]'
-        >>> print(t)
-        10[00[1, 2], 1010[1, 2]]
 
         """
         child_count = 0
@@ -260,7 +289,7 @@ class Trie(object):
         for child in node:
             if child != self.end:
                 child_count += 1
-                child_rep = self.repr_recursive(node[child], delimiter)
+                child_rep = self.repr_brief(node[child], delimiter)
                 if len(child_rep) > 0:
                     my_rep += str(child) + delimiter + child_rep + ", "
                 else:
@@ -274,7 +303,7 @@ class Trie(object):
     def __repr__(self):
         """
         A friendly string representation of the contents of the Trie.  Under
-        the covers uses repr_recursive with no delimiter
+        the covers uses repr_fancy.
 
         >>> t = Trie()
         >>> t.insert([10, 0, 0, 1])
@@ -282,10 +311,19 @@ class Trie(object):
         >>> t.insert([10, 10, 10, 1])
         >>> t.insert([10, 10, 10, 2])
         >>> print(t)
-        10[00[1, 2], 1010[1, 2]]
+        *
+        └──10
+           ├──0
+           │  └──0
+           │     ├──1
+           │     └──2
+           └──10
+              └──10
+                 ├──1
+                 └──2
 
         """
-        return self.repr_recursive(self.root, '')
+        return self.repr_fancy('', '*', self.root, self.root, False)
 
 
 if __name__ == '__main__':