Used isort to sort imports. Also added to the git pre-commit hook.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 import enum
6 import sys
7 from collections import defaultdict
8 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
9
10
11 class ParseError(Exception):
12     """An error encountered while parsing a logical search expression."""
13
14     def __init__(self, message: str):
15         self.message = message
16
17
18 class Document(NamedTuple):
19     """A tuple representing a searchable document."""
20
21     docid: str  # a unique idenfier for the document
22     tags: Set[str]  # an optional set of tags
23     properties: List[Tuple[str, str]]  # an optional set of key->value properties
24     reference: Any  # an optional reference to something else
25
26
27 class Operation(enum.Enum):
28     """A logical search query operation."""
29
30     QUERY = 1
31     CONJUNCTION = 2
32     DISJUNCTION = 3
33     INVERSION = 4
34
35     @staticmethod
36     def from_token(token: str):
37         table = {
38             "not": Operation.INVERSION,
39             "and": Operation.CONJUNCTION,
40             "or": Operation.DISJUNCTION,
41         }
42         return table.get(token, None)
43
44     def num_operands(self) -> Optional[int]:
45         table = {
46             Operation.INVERSION: 1,
47             Operation.CONJUNCTION: 2,
48             Operation.DISJUNCTION: 2,
49         }
50         return table.get(self, None)
51
52
53 class Corpus(object):
54     """A collection of searchable documents.
55
56     >>> c = Corpus()
57     >>> c.add_doc(Document(
58     ...                    docid=1,
59     ...                    tags=set(['urgent', 'important']),
60     ...                    properties=[
61     ...                                ('author', 'Scott'),
62     ...                                ('subject', 'your anniversary')
63     ...                    ],
64     ...                    reference=None,
65     ...                   )
66     ...          )
67     >>> c.add_doc(Document(
68     ...                    docid=2,
69     ...                    tags=set(['important']),
70     ...                    properties=[
71     ...                                ('author', 'Joe'),
72     ...                                ('subject', 'your performance at work')
73     ...                    ],
74     ...                    reference=None,
75     ...                   )
76     ...          )
77     >>> c.add_doc(Document(
78     ...                    docid=3,
79     ...                    tags=set(['urgent']),
80     ...                    properties=[
81     ...                                ('author', 'Scott'),
82     ...                                ('subject', 'car turning in front of you')
83     ...                    ],
84     ...                    reference=None,
85     ...                   )
86     ...          )
87     >>> c.query('author:Scott and important')
88     {1}
89     """
90
91     def __init__(self) -> None:
92         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
93         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
94         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
95         self.documents_by_docid: Dict[str, Document] = {}
96
97     def add_doc(self, doc: Document) -> None:
98         """Add a new Document to the Corpus.  Each Document must have a
99         distinct docid that will serve as its primary identifier.  If
100         the same Document is added multiple times, only the most
101         recent addition is indexed.  If two distinct documents with
102         the same docid are added, the latter klobbers the former in the
103         indexes.
104
105         Each Document may have an optional set of tags which can be
106         used later in expressions to the query method.
107
108         Each Document may have an optional list of key->value tuples
109         which can be used later in expressions to the query method.
110
111         Document includes a user-defined "reference" field which is
112         never interpreted by this module.  This is meant to allow easy
113         mapping between Documents in this corpus and external objects
114         they may represent.
115         """
116
117         if doc.docid in self.documents_by_docid:
118             # Handle collisions; assume that we are re-indexing the
119             # same document so remove it from the indexes before
120             # adding it back again.
121             colliding_doc = self.documents_by_docid[doc.docid]
122             assert colliding_doc.docid == doc.docid
123             del self.documents_by_docid[doc.docid]
124             for tag in colliding_doc.tags:
125                 self.docids_by_tag[tag].remove(doc.docid)
126             for key, value in colliding_doc.properties:
127                 self.docids_by_property[(key, value)].remove(doc.docid)
128                 self.docids_with_property[key].remove(doc.docid)
129
130         # Index the new Document
131         assert doc.docid not in self.documents_by_docid
132         self.documents_by_docid[doc.docid] = doc
133         for tag in doc.tags:
134             self.docids_by_tag[tag].add(doc.docid)
135         for key, value in doc.properties:
136             self.docids_by_property[(key, value)].add(doc.docid)
137             self.docids_with_property[key].add(doc.docid)
138
139     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
140         """Return the set of docids that have a particular tag."""
141
142         return self.docids_by_tag[tag]
143
144     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
145         """Return the set of docids with a tag that contains a str"""
146
147         ret = set()
148         for search_tag in self.docids_by_tag:
149             if tag in search_tag:
150                 for docid in self.docids_by_tag[search_tag]:
151                     ret.add(docid)
152         return ret
153
154     def get_docids_with_property(self, key: str) -> Set[str]:
155         """Return the set of docids that have a particular property no matter
156         what that property's value.
157
158         """
159         return self.docids_with_property[key]
160
161     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
162         """Return the set of docids that have a particular property with a
163         particular value..
164
165         """
166         return self.docids_by_property[(key, value)]
167
168     def invert_docid_set(self, original: Set[str]) -> Set[str]:
169         """Invert a set of docids."""
170
171         return set(
172             [docid for docid in self.documents_by_docid.keys() if docid not in original]
173         )
174
175     def get_doc(self, docid: str) -> Optional[Document]:
176         """Given a docid, retrieve the previously added Document."""
177
178         return self.documents_by_docid.get(docid, None)
179
180     def query(self, query: str) -> Optional[Set[str]]:
181         """Query the corpus for documents that match a logical expression.
182         Returns a (potentially empty) set of docids for the matching
183         (previously added) documents or None on error.
184
185         e.g.
186
187         tag1 and tag2 and not tag3
188
189         (tag1 or tag2) and (tag3 or tag4)
190
191         (tag1 and key2:value2) or (tag2 and key1:value1)
192
193         key:*
194
195         tag1 and key:*
196         """
197
198         try:
199             root = self._parse_query(query)
200         except ParseError as e:
201             print(e.message, file=sys.stderr)
202             return None
203         return root.eval()
204
205     def _parse_query(self, query: str):
206         """Internal parse helper; prefer to use query instead."""
207
208         parens = set(["(", ")"])
209         and_or = set(["and", "or"])
210
211         def operator_precedence(token: str) -> Optional[int]:
212             table = {
213                 "(": 4,  # higher
214                 ")": 4,
215                 "not": 3,
216                 "and": 2,
217                 "or": 1,  # lower
218             }
219             return table.get(token, None)
220
221         def is_operator(token: str) -> bool:
222             return operator_precedence(token) is not None
223
224         def lex(query: str):
225             tokens = query.split()
226             for token in tokens:
227                 # Handle ( and ) operators stuck to the ends of tokens
228                 # that split() doesn't understand.
229                 if len(token) > 1:
230                     first = token[0]
231                     if first in parens:
232                         tail = token[1:]
233                         yield first
234                         token = tail
235                     last = token[-1]
236                     if last in parens:
237                         head = token[0:-1]
238                         yield head
239                         token = last
240                 yield token
241
242         def evaluate(corpus: Corpus, stack: List[str]):
243             node_stack: List[Node] = []
244             for token in stack:
245                 node = None
246                 if not is_operator(token):
247                     node = Node(corpus, Operation.QUERY, [token])
248                 else:
249                     args = []
250                     operation = Operation.from_token(token)
251                     operand_count = operation.num_operands()
252                     if len(node_stack) < operand_count:
253                         raise ParseError(
254                             f"Incorrect number of operations for {operation}"
255                         )
256                     for _ in range(operation.num_operands()):
257                         args.append(node_stack.pop())
258                     node = Node(corpus, operation, args)
259                 node_stack.append(node)
260             return node_stack[0]
261
262         output_stack = []
263         operator_stack = []
264         for token in lex(query):
265             if not is_operator(token):
266                 output_stack.append(token)
267                 continue
268
269             # token is an operator...
270             if token == "(":
271                 operator_stack.append(token)
272             elif token == ")":
273                 ok = False
274                 while len(operator_stack) > 0:
275                     pop_operator = operator_stack.pop()
276                     if pop_operator != "(":
277                         output_stack.append(pop_operator)
278                     else:
279                         ok = True
280                         break
281                 if not ok:
282                     raise ParseError("Unbalanced parenthesis in query expression")
283
284             # and, or, not
285             else:
286                 my_precedence = operator_precedence(token)
287                 if my_precedence is None:
288                     raise ParseError(f"Unknown operator: {token}")
289                 while len(operator_stack) > 0:
290                     peek_operator = operator_stack[-1]
291                     if not is_operator(peek_operator) or peek_operator == "(":
292                         break
293                     peek_precedence = operator_precedence(peek_operator)
294                     if peek_precedence is None:
295                         raise ParseError("Internal error")
296                     if (
297                         (peek_precedence < my_precedence)
298                         or (peek_precedence == my_precedence)
299                         and (peek_operator not in and_or)
300                     ):
301                         break
302                     output_stack.append(operator_stack.pop())
303                 operator_stack.append(token)
304         while len(operator_stack) > 0:
305             token = operator_stack.pop()
306             if token in parens:
307                 raise ParseError("Unbalanced parenthesis in query expression")
308             output_stack.append(token)
309         return evaluate(self, output_stack)
310
311
312 class Node(object):
313     """A query AST node."""
314
315     def __init__(
316         self,
317         corpus: Corpus,
318         op: Operation,
319         operands: Sequence[Union[Node, str]],
320     ):
321         self.corpus = corpus
322         self.op = op
323         self.operands = operands
324
325     def eval(self) -> Set[str]:
326         """Evaluate this node."""
327
328         evaled_operands: List[Union[Set[str], str]] = []
329         for operand in self.operands:
330             if isinstance(operand, Node):
331                 evaled_operands.append(operand.eval())
332             elif isinstance(operand, str):
333                 evaled_operands.append(operand)
334             else:
335                 raise ParseError(f"Unexpected operand: {operand}")
336
337         retval = set()
338         if self.op is Operation.QUERY:
339             for tag in evaled_operands:
340                 if isinstance(tag, str):
341                     if ":" in tag:
342                         try:
343                             key, value = tag.split(":")
344                         except ValueError as v:
345                             raise ParseError(
346                                 f'Invalid key:value syntax at "{tag}"'
347                             ) from v
348                         if value == "*":
349                             r = self.corpus.get_docids_with_property(key)
350                         else:
351                             r = self.corpus.get_docids_by_property(key, value)
352                     else:
353                         r = self.corpus.get_docids_by_exact_tag(tag)
354                     retval.update(r)
355                 else:
356                     raise ParseError(f"Unexpected query {tag}")
357         elif self.op is Operation.DISJUNCTION:
358             if len(evaled_operands) != 2:
359                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
360             retval.update(evaled_operands[0])
361             retval.update(evaled_operands[1])
362         elif self.op is Operation.CONJUNCTION:
363             if len(evaled_operands) != 2:
364                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
365             retval.update(evaled_operands[0])
366             retval = retval.intersection(evaled_operands[1])
367         elif self.op is Operation.INVERSION:
368             if len(evaled_operands) != 1:
369                 raise ParseError("Operation.INVERSION (not) expects one operand.")
370             _ = evaled_operands[0]
371             if isinstance(_, set):
372                 retval.update(self.corpus.invert_docid_set(_))
373             else:
374                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
375         return retval
376
377
378 if __name__ == '__main__':
379     import doctest
380
381     doctest.testmod()