Initial revision
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from collections import defaultdict
6 import enum
7 import sys
8 from typing import (
9     Any,
10     Dict,
11     List,
12     NamedTuple,
13     Optional,
14     Set,
15     Sequence,
16     Tuple,
17     Union,
18 )
19
20
21 class ParseError(Exception):
22     """An error encountered while parsing a logical search expression."""
23
24     def __init__(self, message: str):
25         self.message = message
26
27
28 class Document(NamedTuple):
29     """A tuple representing a searchable document."""
30
31     docid: str  # a unique idenfier for the document
32     tags: Set[str]  # an optional set of tags
33     properties: List[
34         Tuple[str, str]
35     ]  # an optional set of key->value properties
36     reference: Any  # an optional reference to something else
37
38
39 class Operation(enum.Enum):
40     """A logical search query operation."""
41
42     QUERY = 1
43     CONJUNCTION = 2
44     DISJUNCTION = 3
45     INVERSION = 4
46
47     @staticmethod
48     def from_token(token: str):
49         table = {
50             "not": Operation.INVERSION,
51             "and": Operation.CONJUNCTION,
52             "or": Operation.DISJUNCTION,
53         }
54         return table.get(token, None)
55
56     def num_operands(self) -> Optional[int]:
57         table = {
58             Operation.INVERSION: 1,
59             Operation.CONJUNCTION: 2,
60             Operation.DISJUNCTION: 2,
61         }
62         return table.get(self, None)
63
64
65 class Corpus(object):
66     """A collection of searchable documents."""
67
68     def __init__(self) -> None:
69         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
70         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(
71             set
72         )
73         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
74         self.documents_by_docid: Dict[str, Document] = {}
75
76     def add_doc(self, doc: Document) -> None:
77         """Add a new Document to the Corpus.  Each Document must have a
78         distinct docid that will serve as its primary identifier.  If
79         the same Document is added multiple times, only the most
80         recent addition is indexed.  If two distinct documents with
81         the same docid are added, the latter klobbers the former in the
82         indexes.
83
84         Each Document may have an optional set of tags which can be
85         used later in expressions to the query method.
86
87         Each Document may have an optional list of key->value tuples
88         which can be used later in expressions to the query method.
89
90         Document includes a user-defined "reference" field which is
91         never interpreted by this module.  This is meant to allow easy
92         mapping between Documents in this corpus and external objects
93         they may represent.
94         """
95
96         if doc.docid in self.documents_by_docid:
97             # Handle collisions; assume that we are re-indexing the
98             # same document so remove it from the indexes before
99             # adding it back again.
100             colliding_doc = self.documents_by_docid[doc.docid]
101             assert colliding_doc.docid == doc.docid
102             del self.documents_by_docid[doc.docid]
103             for tag in colliding_doc.tags:
104                 self.docids_by_tag[tag].remove(doc.docid)
105             for key, value in colliding_doc.properties:
106                 self.docids_by_property[(key, value)].remove(doc.docid)
107                 self.docids_with_property[key].remove(doc.docid)
108
109         # Index the new Document
110         assert doc.docid not in self.documents_by_docid
111         self.documents_by_docid[doc.docid] = doc
112         for tag in doc.tags:
113             self.docids_by_tag[tag].add(doc.docid)
114         for key, value in doc.properties:
115             self.docids_by_property[(key, value)].add(doc.docid)
116             self.docids_with_property[key].add(doc.docid)
117
118     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
119         """Return the set of docids that have a particular tag."""
120
121         return self.docids_by_tag[tag]
122
123     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
124         """Return the set of docids with a tag that contains a str"""
125
126         ret = set()
127         for search_tag in self.docids_by_tag:
128             if tag in search_tag:
129                 for docid in self.docids_by_tag[search_tag]:
130                     ret.add(docid)
131         return ret
132
133     def get_docids_with_property(self, key: str) -> Set[str]:
134         """Return the set of docids that have a particular property no matter
135         what that property's value.
136         """
137
138         return self.docids_with_property[key]
139
140     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
141         """Return the set of docids that have a particular property with a
142         particular value..
143         """
144
145         return self.docids_by_property[(key, value)]
146
147     def invert_docid_set(self, original: Set[str]) -> Set[str]:
148         """Invert a set of docids."""
149
150         return set(
151             [
152                 docid
153                 for docid in self.documents_by_docid.keys()
154                 if docid not in original
155             ]
156         )
157
158     def get_doc(self, docid: str) -> Optional[Document]:
159         """Given a docid, retrieve the previously added Document."""
160
161         return self.documents_by_docid.get(docid, None)
162
163     def query(self, query: str) -> Optional[Set[str]]:
164         """Query the corpus for documents that match a logical expression.
165         Returns a (potentially empty) set of docids for the matching
166         (previously added) documents or None on error.
167
168         e.g.
169
170         tag1 and tag2 and not tag3
171
172         (tag1 or tag2) and (tag3 or tag4)
173
174         (tag1 and key2:value2) or (tag2 and key1:value1)
175
176         key:*
177
178         tag1 and key:*
179         """
180
181         try:
182             root = self._parse_query(query)
183         except ParseError as e:
184             print(e.message, file=sys.stderr)
185             return None
186         return root.eval()
187
188     def _parse_query(self, query: str):
189         """Internal parse helper; prefer to use query instead."""
190
191         parens = set(["(", ")"])
192         and_or = set(["and", "or"])
193
194         def operator_precedence(token: str) -> Optional[int]:
195             table = {
196                 "(": 4,  # higher
197                 ")": 4,
198                 "not": 3,
199                 "and": 2,
200                 "or": 1,  # lower
201             }
202             return table.get(token, None)
203
204         def is_operator(token: str) -> bool:
205             return operator_precedence(token) is not None
206
207         def lex(query: str):
208             query = query.lower()
209             tokens = query.split()
210             for token in tokens:
211                 # Handle ( and ) operators stuck to the ends of tokens
212                 # that split() doesn't understand.
213                 if len(token) > 1:
214                     first = token[0]
215                     if first in parens:
216                         tail = token[1:]
217                         yield first
218                         token = tail
219                     last = token[-1]
220                     if last in parens:
221                         head = token[0:-1]
222                         yield head
223                         token = last
224                 yield token
225
226         def evaluate(corpus: Corpus, stack: List[str]):
227             node_stack: List[Node] = []
228             for token in stack:
229                 node = None
230                 if not is_operator(token):
231                     node = Node(corpus, Operation.QUERY, [token])
232                 else:
233                     args = []
234                     operation = Operation.from_token(token)
235                     operand_count = operation.num_operands()
236                     if len(node_stack) < operand_count:
237                         raise ParseError(
238                             f"Incorrect number of operations for {operation}"
239                         )
240                     for _ in range(operation.num_operands()):
241                         args.append(node_stack.pop())
242                     node = Node(corpus, operation, args)
243                 node_stack.append(node)
244             return node_stack[0]
245
246         output_stack = []
247         operator_stack = []
248         for token in lex(query):
249             if not is_operator(token):
250                 output_stack.append(token)
251                 continue
252
253             # token is an operator...
254             if token == "(":
255                 operator_stack.append(token)
256             elif token == ")":
257                 ok = False
258                 while len(operator_stack) > 0:
259                     pop_operator = operator_stack.pop()
260                     if pop_operator != "(":
261                         output_stack.append(pop_operator)
262                     else:
263                         ok = True
264                         break
265                 if not ok:
266                     raise ParseError(
267                         "Unbalanced parenthesis in query expression"
268                     )
269
270             # and, or, not
271             else:
272                 my_precedence = operator_precedence(token)
273                 if my_precedence is None:
274                     raise ParseError(f"Unknown operator: {token}")
275                 while len(operator_stack) > 0:
276                     peek_operator = operator_stack[-1]
277                     if not is_operator(peek_operator) or peek_operator == "(":
278                         break
279                     peek_precedence = operator_precedence(peek_operator)
280                     if peek_precedence is None:
281                         raise ParseError("Internal error")
282                     if (
283                         (peek_precedence < my_precedence)
284                         or (peek_precedence == my_precedence)
285                         and (peek_operator not in and_or)
286                     ):
287                         break
288                     output_stack.append(operator_stack.pop())
289                 operator_stack.append(token)
290         while len(operator_stack) > 0:
291             token = operator_stack.pop()
292             if token in parens:
293                 raise ParseError("Unbalanced parenthesis in query expression")
294             output_stack.append(token)
295         return evaluate(self, output_stack)
296
297
298 class Node(object):
299     """A query AST node."""
300
301     def __init__(
302         self,
303         corpus: Corpus,
304         op: Operation,
305         operands: Sequence[Union[Node, str]],
306     ):
307         self.corpus = corpus
308         self.op = op
309         self.operands = operands
310
311     def eval(self) -> Set[str]:
312         """Evaluate this node."""
313
314         evaled_operands: List[Union[Set[str], str]] = []
315         for operand in self.operands:
316             if isinstance(operand, Node):
317                 evaled_operands.append(operand.eval())
318             elif isinstance(operand, str):
319                 evaled_operands.append(operand)
320             else:
321                 raise ParseError(f"Unexpected operand: {operand}")
322
323         retval = set()
324         if self.op is Operation.QUERY:
325             for tag in evaled_operands:
326                 if isinstance(tag, str):
327                     if ":" in tag:
328                         try:
329                             key, value = tag.split(":")
330                         except ValueError as v:
331                             raise ParseError(
332                                 f'Invalid key:value syntax at "{tag}"'
333                             ) from v
334                         if value == "*":
335                             r = self.corpus.get_docids_with_property(key)
336                         else:
337                             r = self.corpus.get_docids_by_property(key, value)
338                     else:
339                         r = self.corpus.get_docids_by_exact_tag(tag)
340                     retval.update(r)
341                 else:
342                     raise ParseError(f"Unexpected query {tag}")
343         elif self.op is Operation.DISJUNCTION:
344             if len(evaled_operands) != 2:
345                 raise ParseError(
346                     "Operation.DISJUNCTION (or) expects two operands."
347                 )
348             retval.update(evaled_operands[0])
349             retval.update(evaled_operands[1])
350         elif self.op is Operation.CONJUNCTION:
351             if len(evaled_operands) != 2:
352                 raise ParseError(
353                     "Operation.CONJUNCTION (and) expects two operands."
354                 )
355             retval.update(evaled_operands[0])
356             retval = retval.intersection(evaled_operands[1])
357         elif self.op is Operation.INVERSION:
358             if len(evaled_operands) != 1:
359                 raise ParseError(
360                     "Operation.INVERSION (not) expects one operand."
361                 )
362             _ = evaled_operands[0]
363             if isinstance(_, set):
364                 retval.update(self.corpus.invert_docid_set(_))
365             else:
366                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
367         return retval