3ebaee5652040bef7d67620fb0e69f028c55c986
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from collections import defaultdict
6 import enum
7 import sys
8 from typing import (
9     Any,
10     Dict,
11     List,
12     NamedTuple,
13     Optional,
14     Set,
15     Sequence,
16     Tuple,
17     Union,
18 )
19
20
21 class ParseError(Exception):
22     """An error encountered while parsing a logical search expression."""
23
24     def __init__(self, message: str):
25         self.message = message
26
27
28 class Document(NamedTuple):
29     """A tuple representing a searchable document."""
30
31     docid: str  # a unique idenfier for the document
32     tags: Set[str]  # an optional set of tags
33     properties: List[
34         Tuple[str, str]
35     ]  # an optional set of key->value properties
36     reference: Any  # an optional reference to something else
37
38
39 class Operation(enum.Enum):
40     """A logical search query operation."""
41
42     QUERY = 1
43     CONJUNCTION = 2
44     DISJUNCTION = 3
45     INVERSION = 4
46
47     @staticmethod
48     def from_token(token: str):
49         table = {
50             "not": Operation.INVERSION,
51             "and": Operation.CONJUNCTION,
52             "or": Operation.DISJUNCTION,
53         }
54         return table.get(token, None)
55
56     def num_operands(self) -> Optional[int]:
57         table = {
58             Operation.INVERSION: 1,
59             Operation.CONJUNCTION: 2,
60             Operation.DISJUNCTION: 2,
61         }
62         return table.get(self, None)
63
64
65 class Corpus(object):
66     """A collection of searchable documents.
67
68     >>> c = Corpus()
69     >>> c.add_doc(Document(
70     ...                    docid=1,
71     ...                    tags=set(['urgent', 'important']),
72     ...                    properties=[
73     ...                                ('author', 'Scott'),
74     ...                                ('subject', 'your anniversary')
75     ...                    ],
76     ...                    reference=None,
77     ...                   )
78     ...          )
79     >>> c.add_doc(Document(
80     ...                    docid=2,
81     ...                    tags=set(['important']),
82     ...                    properties=[
83     ...                                ('author', 'Joe'),
84     ...                                ('subject', 'your performance at work')
85     ...                    ],
86     ...                    reference=None,
87     ...                   )
88     ...          )
89     >>> c.add_doc(Document(
90     ...                    docid=3,
91     ...                    tags=set(['urgent']),
92     ...                    properties=[
93     ...                                ('author', 'Scott'),
94     ...                                ('subject', 'car turning in front of you')
95     ...                    ],
96     ...                    reference=None,
97     ...                   )
98     ...          )
99     >>> c.query('author:Scott and important')
100     {1}
101     """
102
103     def __init__(self) -> None:
104         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
105         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(
106             set
107         )
108         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
109         self.documents_by_docid: Dict[str, Document] = {}
110
111     def add_doc(self, doc: Document) -> None:
112         """Add a new Document to the Corpus.  Each Document must have a
113         distinct docid that will serve as its primary identifier.  If
114         the same Document is added multiple times, only the most
115         recent addition is indexed.  If two distinct documents with
116         the same docid are added, the latter klobbers the former in the
117         indexes.
118
119         Each Document may have an optional set of tags which can be
120         used later in expressions to the query method.
121
122         Each Document may have an optional list of key->value tuples
123         which can be used later in expressions to the query method.
124
125         Document includes a user-defined "reference" field which is
126         never interpreted by this module.  This is meant to allow easy
127         mapping between Documents in this corpus and external objects
128         they may represent.
129         """
130
131         if doc.docid in self.documents_by_docid:
132             # Handle collisions; assume that we are re-indexing the
133             # same document so remove it from the indexes before
134             # adding it back again.
135             colliding_doc = self.documents_by_docid[doc.docid]
136             assert colliding_doc.docid == doc.docid
137             del self.documents_by_docid[doc.docid]
138             for tag in colliding_doc.tags:
139                 self.docids_by_tag[tag].remove(doc.docid)
140             for key, value in colliding_doc.properties:
141                 self.docids_by_property[(key, value)].remove(doc.docid)
142                 self.docids_with_property[key].remove(doc.docid)
143
144         # Index the new Document
145         assert doc.docid not in self.documents_by_docid
146         self.documents_by_docid[doc.docid] = doc
147         for tag in doc.tags:
148             self.docids_by_tag[tag].add(doc.docid)
149         for key, value in doc.properties:
150             self.docids_by_property[(key, value)].add(doc.docid)
151             self.docids_with_property[key].add(doc.docid)
152
153     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
154         """Return the set of docids that have a particular tag."""
155
156         return self.docids_by_tag[tag]
157
158     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
159         """Return the set of docids with a tag that contains a str"""
160
161         ret = set()
162         for search_tag in self.docids_by_tag:
163             if tag in search_tag:
164                 for docid in self.docids_by_tag[search_tag]:
165                     ret.add(docid)
166         return ret
167
168     def get_docids_with_property(self, key: str) -> Set[str]:
169         """Return the set of docids that have a particular property no matter
170         what that property's value.
171
172         """
173         return self.docids_with_property[key]
174
175     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
176         """Return the set of docids that have a particular property with a
177         particular value..
178
179         """
180         return self.docids_by_property[(key, value)]
181
182     def invert_docid_set(self, original: Set[str]) -> Set[str]:
183         """Invert a set of docids."""
184
185         return set(
186             [
187                 docid
188                 for docid in self.documents_by_docid.keys()
189                 if docid not in original
190             ]
191         )
192
193     def get_doc(self, docid: str) -> Optional[Document]:
194         """Given a docid, retrieve the previously added Document."""
195
196         return self.documents_by_docid.get(docid, None)
197
198     def query(self, query: str) -> Optional[Set[str]]:
199         """Query the corpus for documents that match a logical expression.
200         Returns a (potentially empty) set of docids for the matching
201         (previously added) documents or None on error.
202
203         e.g.
204
205         tag1 and tag2 and not tag3
206
207         (tag1 or tag2) and (tag3 or tag4)
208
209         (tag1 and key2:value2) or (tag2 and key1:value1)
210
211         key:*
212
213         tag1 and key:*
214         """
215
216         try:
217             root = self._parse_query(query)
218         except ParseError as e:
219             print(e.message, file=sys.stderr)
220             return None
221         return root.eval()
222
223     def _parse_query(self, query: str):
224         """Internal parse helper; prefer to use query instead."""
225
226         parens = set(["(", ")"])
227         and_or = set(["and", "or"])
228
229         def operator_precedence(token: str) -> Optional[int]:
230             table = {
231                 "(": 4,  # higher
232                 ")": 4,
233                 "not": 3,
234                 "and": 2,
235                 "or": 1,  # lower
236             }
237             return table.get(token, None)
238
239         def is_operator(token: str) -> bool:
240             return operator_precedence(token) is not None
241
242         def lex(query: str):
243             tokens = query.split()
244             for token in tokens:
245                 # Handle ( and ) operators stuck to the ends of tokens
246                 # that split() doesn't understand.
247                 if len(token) > 1:
248                     first = token[0]
249                     if first in parens:
250                         tail = token[1:]
251                         yield first
252                         token = tail
253                     last = token[-1]
254                     if last in parens:
255                         head = token[0:-1]
256                         yield head
257                         token = last
258                 yield token
259
260         def evaluate(corpus: Corpus, stack: List[str]):
261             node_stack: List[Node] = []
262             for token in stack:
263                 node = None
264                 if not is_operator(token):
265                     node = Node(corpus, Operation.QUERY, [token])
266                 else:
267                     args = []
268                     operation = Operation.from_token(token)
269                     operand_count = operation.num_operands()
270                     if len(node_stack) < operand_count:
271                         raise ParseError(
272                             f"Incorrect number of operations for {operation}"
273                         )
274                     for _ in range(operation.num_operands()):
275                         args.append(node_stack.pop())
276                     node = Node(corpus, operation, args)
277                 node_stack.append(node)
278             return node_stack[0]
279
280         output_stack = []
281         operator_stack = []
282         for token in lex(query):
283             if not is_operator(token):
284                 output_stack.append(token)
285                 continue
286
287             # token is an operator...
288             if token == "(":
289                 operator_stack.append(token)
290             elif token == ")":
291                 ok = False
292                 while len(operator_stack) > 0:
293                     pop_operator = operator_stack.pop()
294                     if pop_operator != "(":
295                         output_stack.append(pop_operator)
296                     else:
297                         ok = True
298                         break
299                 if not ok:
300                     raise ParseError(
301                         "Unbalanced parenthesis in query expression"
302                     )
303
304             # and, or, not
305             else:
306                 my_precedence = operator_precedence(token)
307                 if my_precedence is None:
308                     raise ParseError(f"Unknown operator: {token}")
309                 while len(operator_stack) > 0:
310                     peek_operator = operator_stack[-1]
311                     if not is_operator(peek_operator) or peek_operator == "(":
312                         break
313                     peek_precedence = operator_precedence(peek_operator)
314                     if peek_precedence is None:
315                         raise ParseError("Internal error")
316                     if (
317                         (peek_precedence < my_precedence)
318                         or (peek_precedence == my_precedence)
319                         and (peek_operator not in and_or)
320                     ):
321                         break
322                     output_stack.append(operator_stack.pop())
323                 operator_stack.append(token)
324         while len(operator_stack) > 0:
325             token = operator_stack.pop()
326             if token in parens:
327                 raise ParseError("Unbalanced parenthesis in query expression")
328             output_stack.append(token)
329         return evaluate(self, output_stack)
330
331
332 class Node(object):
333     """A query AST node."""
334
335     def __init__(
336         self,
337         corpus: Corpus,
338         op: Operation,
339         operands: Sequence[Union[Node, str]],
340     ):
341         self.corpus = corpus
342         self.op = op
343         self.operands = operands
344
345     def eval(self) -> Set[str]:
346         """Evaluate this node."""
347
348         evaled_operands: List[Union[Set[str], str]] = []
349         for operand in self.operands:
350             if isinstance(operand, Node):
351                 evaled_operands.append(operand.eval())
352             elif isinstance(operand, str):
353                 evaled_operands.append(operand)
354             else:
355                 raise ParseError(f"Unexpected operand: {operand}")
356
357         retval = set()
358         if self.op is Operation.QUERY:
359             for tag in evaled_operands:
360                 if isinstance(tag, str):
361                     if ":" in tag:
362                         try:
363                             key, value = tag.split(":")
364                         except ValueError as v:
365                             raise ParseError(
366                                 f'Invalid key:value syntax at "{tag}"'
367                             ) from v
368                         if value == "*":
369                             r = self.corpus.get_docids_with_property(key)
370                         else:
371                             r = self.corpus.get_docids_by_property(key, value)
372                     else:
373                         r = self.corpus.get_docids_by_exact_tag(tag)
374                     retval.update(r)
375                 else:
376                     raise ParseError(f"Unexpected query {tag}")
377         elif self.op is Operation.DISJUNCTION:
378             if len(evaled_operands) != 2:
379                 raise ParseError(
380                     "Operation.DISJUNCTION (or) expects two operands."
381                 )
382             retval.update(evaled_operands[0])
383             retval.update(evaled_operands[1])
384         elif self.op is Operation.CONJUNCTION:
385             if len(evaled_operands) != 2:
386                 raise ParseError(
387                     "Operation.CONJUNCTION (and) expects two operands."
388                 )
389             retval.update(evaled_operands[0])
390             retval = retval.intersection(evaled_operands[1])
391         elif self.op is Operation.INVERSION:
392             if len(evaled_operands) != 1:
393                 raise ParseError(
394                     "Operation.INVERSION (not) expects one operand."
395                 )
396             _ = evaled_operands[0]
397             if isinstance(_, set):
398                 retval.update(self.corpus.invert_docid_set(_))
399             else:
400                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
401         return retval
402
403
404 if __name__ == '__main__':
405     import doctest
406     doctest.testmod()