Ahem. Still running black?
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 from collections import defaultdict
6 import enum
7 import sys
8 from typing import (
9     Any,
10     Dict,
11     List,
12     NamedTuple,
13     Optional,
14     Set,
15     Sequence,
16     Tuple,
17     Union,
18 )
19
20
21 class ParseError(Exception):
22     """An error encountered while parsing a logical search expression."""
23
24     def __init__(self, message: str):
25         self.message = message
26
27
28 class Document(NamedTuple):
29     """A tuple representing a searchable document."""
30
31     docid: str  # a unique idenfier for the document
32     tags: Set[str]  # an optional set of tags
33     properties: List[Tuple[str, str]]  # an optional set of key->value properties
34     reference: Any  # an optional reference to something else
35
36
37 class Operation(enum.Enum):
38     """A logical search query operation."""
39
40     QUERY = 1
41     CONJUNCTION = 2
42     DISJUNCTION = 3
43     INVERSION = 4
44
45     @staticmethod
46     def from_token(token: str):
47         table = {
48             "not": Operation.INVERSION,
49             "and": Operation.CONJUNCTION,
50             "or": Operation.DISJUNCTION,
51         }
52         return table.get(token, None)
53
54     def num_operands(self) -> Optional[int]:
55         table = {
56             Operation.INVERSION: 1,
57             Operation.CONJUNCTION: 2,
58             Operation.DISJUNCTION: 2,
59         }
60         return table.get(self, None)
61
62
63 class Corpus(object):
64     """A collection of searchable documents.
65
66     >>> c = Corpus()
67     >>> c.add_doc(Document(
68     ...                    docid=1,
69     ...                    tags=set(['urgent', 'important']),
70     ...                    properties=[
71     ...                                ('author', 'Scott'),
72     ...                                ('subject', 'your anniversary')
73     ...                    ],
74     ...                    reference=None,
75     ...                   )
76     ...          )
77     >>> c.add_doc(Document(
78     ...                    docid=2,
79     ...                    tags=set(['important']),
80     ...                    properties=[
81     ...                                ('author', 'Joe'),
82     ...                                ('subject', 'your performance at work')
83     ...                    ],
84     ...                    reference=None,
85     ...                   )
86     ...          )
87     >>> c.add_doc(Document(
88     ...                    docid=3,
89     ...                    tags=set(['urgent']),
90     ...                    properties=[
91     ...                                ('author', 'Scott'),
92     ...                                ('subject', 'car turning in front of you')
93     ...                    ],
94     ...                    reference=None,
95     ...                   )
96     ...          )
97     >>> c.query('author:Scott and important')
98     {1}
99     """
100
101     def __init__(self) -> None:
102         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
103         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
104         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
105         self.documents_by_docid: Dict[str, Document] = {}
106
107     def add_doc(self, doc: Document) -> None:
108         """Add a new Document to the Corpus.  Each Document must have a
109         distinct docid that will serve as its primary identifier.  If
110         the same Document is added multiple times, only the most
111         recent addition is indexed.  If two distinct documents with
112         the same docid are added, the latter klobbers the former in the
113         indexes.
114
115         Each Document may have an optional set of tags which can be
116         used later in expressions to the query method.
117
118         Each Document may have an optional list of key->value tuples
119         which can be used later in expressions to the query method.
120
121         Document includes a user-defined "reference" field which is
122         never interpreted by this module.  This is meant to allow easy
123         mapping between Documents in this corpus and external objects
124         they may represent.
125         """
126
127         if doc.docid in self.documents_by_docid:
128             # Handle collisions; assume that we are re-indexing the
129             # same document so remove it from the indexes before
130             # adding it back again.
131             colliding_doc = self.documents_by_docid[doc.docid]
132             assert colliding_doc.docid == doc.docid
133             del self.documents_by_docid[doc.docid]
134             for tag in colliding_doc.tags:
135                 self.docids_by_tag[tag].remove(doc.docid)
136             for key, value in colliding_doc.properties:
137                 self.docids_by_property[(key, value)].remove(doc.docid)
138                 self.docids_with_property[key].remove(doc.docid)
139
140         # Index the new Document
141         assert doc.docid not in self.documents_by_docid
142         self.documents_by_docid[doc.docid] = doc
143         for tag in doc.tags:
144             self.docids_by_tag[tag].add(doc.docid)
145         for key, value in doc.properties:
146             self.docids_by_property[(key, value)].add(doc.docid)
147             self.docids_with_property[key].add(doc.docid)
148
149     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
150         """Return the set of docids that have a particular tag."""
151
152         return self.docids_by_tag[tag]
153
154     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
155         """Return the set of docids with a tag that contains a str"""
156
157         ret = set()
158         for search_tag in self.docids_by_tag:
159             if tag in search_tag:
160                 for docid in self.docids_by_tag[search_tag]:
161                     ret.add(docid)
162         return ret
163
164     def get_docids_with_property(self, key: str) -> Set[str]:
165         """Return the set of docids that have a particular property no matter
166         what that property's value.
167
168         """
169         return self.docids_with_property[key]
170
171     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
172         """Return the set of docids that have a particular property with a
173         particular value..
174
175         """
176         return self.docids_by_property[(key, value)]
177
178     def invert_docid_set(self, original: Set[str]) -> Set[str]:
179         """Invert a set of docids."""
180
181         return set(
182             [docid for docid in self.documents_by_docid.keys() if docid not in original]
183         )
184
185     def get_doc(self, docid: str) -> Optional[Document]:
186         """Given a docid, retrieve the previously added Document."""
187
188         return self.documents_by_docid.get(docid, None)
189
190     def query(self, query: str) -> Optional[Set[str]]:
191         """Query the corpus for documents that match a logical expression.
192         Returns a (potentially empty) set of docids for the matching
193         (previously added) documents or None on error.
194
195         e.g.
196
197         tag1 and tag2 and not tag3
198
199         (tag1 or tag2) and (tag3 or tag4)
200
201         (tag1 and key2:value2) or (tag2 and key1:value1)
202
203         key:*
204
205         tag1 and key:*
206         """
207
208         try:
209             root = self._parse_query(query)
210         except ParseError as e:
211             print(e.message, file=sys.stderr)
212             return None
213         return root.eval()
214
215     def _parse_query(self, query: str):
216         """Internal parse helper; prefer to use query instead."""
217
218         parens = set(["(", ")"])
219         and_or = set(["and", "or"])
220
221         def operator_precedence(token: str) -> Optional[int]:
222             table = {
223                 "(": 4,  # higher
224                 ")": 4,
225                 "not": 3,
226                 "and": 2,
227                 "or": 1,  # lower
228             }
229             return table.get(token, None)
230
231         def is_operator(token: str) -> bool:
232             return operator_precedence(token) is not None
233
234         def lex(query: str):
235             tokens = query.split()
236             for token in tokens:
237                 # Handle ( and ) operators stuck to the ends of tokens
238                 # that split() doesn't understand.
239                 if len(token) > 1:
240                     first = token[0]
241                     if first in parens:
242                         tail = token[1:]
243                         yield first
244                         token = tail
245                     last = token[-1]
246                     if last in parens:
247                         head = token[0:-1]
248                         yield head
249                         token = last
250                 yield token
251
252         def evaluate(corpus: Corpus, stack: List[str]):
253             node_stack: List[Node] = []
254             for token in stack:
255                 node = None
256                 if not is_operator(token):
257                     node = Node(corpus, Operation.QUERY, [token])
258                 else:
259                     args = []
260                     operation = Operation.from_token(token)
261                     operand_count = operation.num_operands()
262                     if len(node_stack) < operand_count:
263                         raise ParseError(
264                             f"Incorrect number of operations for {operation}"
265                         )
266                     for _ in range(operation.num_operands()):
267                         args.append(node_stack.pop())
268                     node = Node(corpus, operation, args)
269                 node_stack.append(node)
270             return node_stack[0]
271
272         output_stack = []
273         operator_stack = []
274         for token in lex(query):
275             if not is_operator(token):
276                 output_stack.append(token)
277                 continue
278
279             # token is an operator...
280             if token == "(":
281                 operator_stack.append(token)
282             elif token == ")":
283                 ok = False
284                 while len(operator_stack) > 0:
285                     pop_operator = operator_stack.pop()
286                     if pop_operator != "(":
287                         output_stack.append(pop_operator)
288                     else:
289                         ok = True
290                         break
291                 if not ok:
292                     raise ParseError("Unbalanced parenthesis in query expression")
293
294             # and, or, not
295             else:
296                 my_precedence = operator_precedence(token)
297                 if my_precedence is None:
298                     raise ParseError(f"Unknown operator: {token}")
299                 while len(operator_stack) > 0:
300                     peek_operator = operator_stack[-1]
301                     if not is_operator(peek_operator) or peek_operator == "(":
302                         break
303                     peek_precedence = operator_precedence(peek_operator)
304                     if peek_precedence is None:
305                         raise ParseError("Internal error")
306                     if (
307                         (peek_precedence < my_precedence)
308                         or (peek_precedence == my_precedence)
309                         and (peek_operator not in and_or)
310                     ):
311                         break
312                     output_stack.append(operator_stack.pop())
313                 operator_stack.append(token)
314         while len(operator_stack) > 0:
315             token = operator_stack.pop()
316             if token in parens:
317                 raise ParseError("Unbalanced parenthesis in query expression")
318             output_stack.append(token)
319         return evaluate(self, output_stack)
320
321
322 class Node(object):
323     """A query AST node."""
324
325     def __init__(
326         self,
327         corpus: Corpus,
328         op: Operation,
329         operands: Sequence[Union[Node, str]],
330     ):
331         self.corpus = corpus
332         self.op = op
333         self.operands = operands
334
335     def eval(self) -> Set[str]:
336         """Evaluate this node."""
337
338         evaled_operands: List[Union[Set[str], str]] = []
339         for operand in self.operands:
340             if isinstance(operand, Node):
341                 evaled_operands.append(operand.eval())
342             elif isinstance(operand, str):
343                 evaled_operands.append(operand)
344             else:
345                 raise ParseError(f"Unexpected operand: {operand}")
346
347         retval = set()
348         if self.op is Operation.QUERY:
349             for tag in evaled_operands:
350                 if isinstance(tag, str):
351                     if ":" in tag:
352                         try:
353                             key, value = tag.split(":")
354                         except ValueError as v:
355                             raise ParseError(
356                                 f'Invalid key:value syntax at "{tag}"'
357                             ) from v
358                         if value == "*":
359                             r = self.corpus.get_docids_with_property(key)
360                         else:
361                             r = self.corpus.get_docids_by_property(key, value)
362                     else:
363                         r = self.corpus.get_docids_by_exact_tag(tag)
364                     retval.update(r)
365                 else:
366                     raise ParseError(f"Unexpected query {tag}")
367         elif self.op is Operation.DISJUNCTION:
368             if len(evaled_operands) != 2:
369                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
370             retval.update(evaled_operands[0])
371             retval.update(evaled_operands[1])
372         elif self.op is Operation.CONJUNCTION:
373             if len(evaled_operands) != 2:
374                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
375             retval.update(evaled_operands[0])
376             retval = retval.intersection(evaled_operands[1])
377         elif self.op is Operation.INVERSION:
378             if len(evaled_operands) != 1:
379                 raise ParseError("Operation.INVERSION (not) expects one operand.")
380             _ = evaled_operands[0]
381             if isinstance(_, set):
382                 retval.update(self.corpus.invert_docid_set(_))
383             else:
384                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
385         return retval
386
387
388 if __name__ == '__main__':
389     import doctest
390
391     doctest.testmod()