805ec223010b93b2a1bf68e1fdee9467daac14aa
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from collections import defaultdict
6 import enum
7 import sys
8 from typing import (
9     Any,
10     Dict,
11     List,
12     NamedTuple,
13     Optional,
14     Set,
15     Sequence,
16     Tuple,
17     Union,
18 )
19
20
21 class ParseError(Exception):
22     """An error encountered while parsing a logical search expression."""
23
24     def __init__(self, message: str):
25         self.message = message
26
27
28 class Document(NamedTuple):
29     """A tuple representing a searchable document."""
30
31     docid: str  # a unique idenfier for the document
32     tags: Set[str]  # an optional set of tags
33     properties: List[
34         Tuple[str, str]
35     ]  # an optional set of key->value properties
36     reference: Any  # an optional reference to something else
37
38
39 class Operation(enum.Enum):
40     """A logical search query operation."""
41
42     QUERY = 1
43     CONJUNCTION = 2
44     DISJUNCTION = 3
45     INVERSION = 4
46
47     @staticmethod
48     def from_token(token: str):
49         table = {
50             "not": Operation.INVERSION,
51             "and": Operation.CONJUNCTION,
52             "or": Operation.DISJUNCTION,
53         }
54         return table.get(token, None)
55
56     def num_operands(self) -> Optional[int]:
57         table = {
58             Operation.INVERSION: 1,
59             Operation.CONJUNCTION: 2,
60             Operation.DISJUNCTION: 2,
61         }
62         return table.get(self, None)
63
64
65 class Corpus(object):
66     """A collection of searchable documents.
67
68     >>> c = Corpus()
69     >>> c.add_doc(Document(
70     ...                    docid=1,
71     ...                    tags=set(['urgent', 'important']),
72     ...                    properties=[
73     ...                                ('author', 'Scott'),
74     ...                                ('subject', 'your anniversary')
75     ...                    ],
76     ...                    reference=None,
77     ...                   )
78     ...          )
79     >>> c.add_doc(Document(
80     ...                    docid=2,
81     ...                    tags=set(['important']),
82     ...                    properties=[
83     ...                                ('author', 'Joe'),
84     ...                                ('subject', 'your performance at work')
85     ...                    ],
86     ...                    reference=None,
87     ...                   )
88     ...          )
89     >>> c.query('author:Scott and important')
90     {1}
91     """
92
93     def __init__(self) -> None:
94         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
95         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(
96             set
97         )
98         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
99         self.documents_by_docid: Dict[str, Document] = {}
100
101     def add_doc(self, doc: Document) -> None:
102         """Add a new Document to the Corpus.  Each Document must have a
103         distinct docid that will serve as its primary identifier.  If
104         the same Document is added multiple times, only the most
105         recent addition is indexed.  If two distinct documents with
106         the same docid are added, the latter klobbers the former in the
107         indexes.
108
109         Each Document may have an optional set of tags which can be
110         used later in expressions to the query method.
111
112         Each Document may have an optional list of key->value tuples
113         which can be used later in expressions to the query method.
114
115         Document includes a user-defined "reference" field which is
116         never interpreted by this module.  This is meant to allow easy
117         mapping between Documents in this corpus and external objects
118         they may represent.
119         """
120
121         if doc.docid in self.documents_by_docid:
122             # Handle collisions; assume that we are re-indexing the
123             # same document so remove it from the indexes before
124             # adding it back again.
125             colliding_doc = self.documents_by_docid[doc.docid]
126             assert colliding_doc.docid == doc.docid
127             del self.documents_by_docid[doc.docid]
128             for tag in colliding_doc.tags:
129                 self.docids_by_tag[tag].remove(doc.docid)
130             for key, value in colliding_doc.properties:
131                 self.docids_by_property[(key, value)].remove(doc.docid)
132                 self.docids_with_property[key].remove(doc.docid)
133
134         # Index the new Document
135         assert doc.docid not in self.documents_by_docid
136         self.documents_by_docid[doc.docid] = doc
137         for tag in doc.tags:
138             self.docids_by_tag[tag].add(doc.docid)
139         for key, value in doc.properties:
140             self.docids_by_property[(key, value)].add(doc.docid)
141             self.docids_with_property[key].add(doc.docid)
142
143     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
144         """Return the set of docids that have a particular tag."""
145
146         return self.docids_by_tag[tag]
147
148     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
149         """Return the set of docids with a tag that contains a str"""
150
151         ret = set()
152         for search_tag in self.docids_by_tag:
153             if tag in search_tag:
154                 for docid in self.docids_by_tag[search_tag]:
155                     ret.add(docid)
156         return ret
157
158     def get_docids_with_property(self, key: str) -> Set[str]:
159         """Return the set of docids that have a particular property no matter
160         what that property's value.
161
162         """
163         return self.docids_with_property[key]
164
165     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
166         """Return the set of docids that have a particular property with a
167         particular value..
168
169         """
170         return self.docids_by_property[(key, value)]
171
172     def invert_docid_set(self, original: Set[str]) -> Set[str]:
173         """Invert a set of docids."""
174
175         return set(
176             [
177                 docid
178                 for docid in self.documents_by_docid.keys()
179                 if docid not in original
180             ]
181         )
182
183     def get_doc(self, docid: str) -> Optional[Document]:
184         """Given a docid, retrieve the previously added Document."""
185
186         return self.documents_by_docid.get(docid, None)
187
188     def query(self, query: str) -> Optional[Set[str]]:
189         """Query the corpus for documents that match a logical expression.
190         Returns a (potentially empty) set of docids for the matching
191         (previously added) documents or None on error.
192
193         e.g.
194
195         tag1 and tag2 and not tag3
196
197         (tag1 or tag2) and (tag3 or tag4)
198
199         (tag1 and key2:value2) or (tag2 and key1:value1)
200
201         key:*
202
203         tag1 and key:*
204         """
205
206         try:
207             root = self._parse_query(query)
208         except ParseError as e:
209             print(e.message, file=sys.stderr)
210             return None
211         return root.eval()
212
213     def _parse_query(self, query: str):
214         """Internal parse helper; prefer to use query instead."""
215
216         parens = set(["(", ")"])
217         and_or = set(["and", "or"])
218
219         def operator_precedence(token: str) -> Optional[int]:
220             table = {
221                 "(": 4,  # higher
222                 ")": 4,
223                 "not": 3,
224                 "and": 2,
225                 "or": 1,  # lower
226             }
227             return table.get(token, None)
228
229         def is_operator(token: str) -> bool:
230             return operator_precedence(token) is not None
231
232         def lex(query: str):
233             tokens = query.split()
234             for token in tokens:
235                 # Handle ( and ) operators stuck to the ends of tokens
236                 # that split() doesn't understand.
237                 if len(token) > 1:
238                     first = token[0]
239                     if first in parens:
240                         tail = token[1:]
241                         yield first
242                         token = tail
243                     last = token[-1]
244                     if last in parens:
245                         head = token[0:-1]
246                         yield head
247                         token = last
248                 yield token
249
250         def evaluate(corpus: Corpus, stack: List[str]):
251             node_stack: List[Node] = []
252             for token in stack:
253                 node = None
254                 if not is_operator(token):
255                     node = Node(corpus, Operation.QUERY, [token])
256                 else:
257                     args = []
258                     operation = Operation.from_token(token)
259                     operand_count = operation.num_operands()
260                     if len(node_stack) < operand_count:
261                         raise ParseError(
262                             f"Incorrect number of operations for {operation}"
263                         )
264                     for _ in range(operation.num_operands()):
265                         args.append(node_stack.pop())
266                     node = Node(corpus, operation, args)
267                 node_stack.append(node)
268             return node_stack[0]
269
270         output_stack = []
271         operator_stack = []
272         for token in lex(query):
273             if not is_operator(token):
274                 output_stack.append(token)
275                 continue
276
277             # token is an operator...
278             if token == "(":
279                 operator_stack.append(token)
280             elif token == ")":
281                 ok = False
282                 while len(operator_stack) > 0:
283                     pop_operator = operator_stack.pop()
284                     if pop_operator != "(":
285                         output_stack.append(pop_operator)
286                     else:
287                         ok = True
288                         break
289                 if not ok:
290                     raise ParseError(
291                         "Unbalanced parenthesis in query expression"
292                     )
293
294             # and, or, not
295             else:
296                 my_precedence = operator_precedence(token)
297                 if my_precedence is None:
298                     raise ParseError(f"Unknown operator: {token}")
299                 while len(operator_stack) > 0:
300                     peek_operator = operator_stack[-1]
301                     if not is_operator(peek_operator) or peek_operator == "(":
302                         break
303                     peek_precedence = operator_precedence(peek_operator)
304                     if peek_precedence is None:
305                         raise ParseError("Internal error")
306                     if (
307                         (peek_precedence < my_precedence)
308                         or (peek_precedence == my_precedence)
309                         and (peek_operator not in and_or)
310                     ):
311                         break
312                     output_stack.append(operator_stack.pop())
313                 operator_stack.append(token)
314         while len(operator_stack) > 0:
315             token = operator_stack.pop()
316             if token in parens:
317                 raise ParseError("Unbalanced parenthesis in query expression")
318             output_stack.append(token)
319         return evaluate(self, output_stack)
320
321
322 class Node(object):
323     """A query AST node."""
324
325     def __init__(
326         self,
327         corpus: Corpus,
328         op: Operation,
329         operands: Sequence[Union[Node, str]],
330     ):
331         self.corpus = corpus
332         self.op = op
333         self.operands = operands
334
335     def eval(self) -> Set[str]:
336         """Evaluate this node."""
337
338         evaled_operands: List[Union[Set[str], str]] = []
339         for operand in self.operands:
340             if isinstance(operand, Node):
341                 evaled_operands.append(operand.eval())
342             elif isinstance(operand, str):
343                 evaled_operands.append(operand)
344             else:
345                 raise ParseError(f"Unexpected operand: {operand}")
346
347         retval = set()
348         if self.op is Operation.QUERY:
349             for tag in evaled_operands:
350                 if isinstance(tag, str):
351                     if ":" in tag:
352                         try:
353                             key, value = tag.split(":")
354                         except ValueError as v:
355                             raise ParseError(
356                                 f'Invalid key:value syntax at "{tag}"'
357                             ) from v
358                         if value == "*":
359                             r = self.corpus.get_docids_with_property(key)
360                         else:
361                             r = self.corpus.get_docids_by_property(key, value)
362                     else:
363                         r = self.corpus.get_docids_by_exact_tag(tag)
364                     retval.update(r)
365                 else:
366                     raise ParseError(f"Unexpected query {tag}")
367         elif self.op is Operation.DISJUNCTION:
368             if len(evaled_operands) != 2:
369                 raise ParseError(
370                     "Operation.DISJUNCTION (or) expects two operands."
371                 )
372             retval.update(evaled_operands[0])
373             retval.update(evaled_operands[1])
374         elif self.op is Operation.CONJUNCTION:
375             if len(evaled_operands) != 2:
376                 raise ParseError(
377                     "Operation.CONJUNCTION (and) expects two operands."
378                 )
379             retval.update(evaled_operands[0])
380             retval = retval.intersection(evaled_operands[1])
381         elif self.op is Operation.INVERSION:
382             if len(evaled_operands) != 1:
383                 raise ParseError(
384                     "Operation.INVERSION (not) expects one operand."
385                 )
386             _ = evaled_operands[0]
387             if isinstance(_, set):
388                 retval.update(self.corpus.invert_docid_set(_))
389             else:
390                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
391         return retval
392
393
394 if __name__ == '__main__':
395     import doctest
396     doctest.testmod()