Change settings in flake8 and black.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 from __future__ import annotations
4
5 import enum
6 import sys
7 from collections import defaultdict
8 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
9
10
11 class ParseError(Exception):
12     """An error encountered while parsing a logical search expression."""
13
14     def __init__(self, message: str):
15         self.message = message
16
17
18 class Document(NamedTuple):
19     """A tuple representing a searchable document."""
20
21     docid: str  # a unique idenfier for the document
22     tags: Set[str]  # an optional set of tags
23     properties: List[Tuple[str, str]]  # an optional set of key->value properties
24     reference: Any  # an optional reference to something else
25
26
27 class Operation(enum.Enum):
28     """A logical search query operation."""
29
30     QUERY = 1
31     CONJUNCTION = 2
32     DISJUNCTION = 3
33     INVERSION = 4
34
35     @staticmethod
36     def from_token(token: str):
37         table = {
38             "not": Operation.INVERSION,
39             "and": Operation.CONJUNCTION,
40             "or": Operation.DISJUNCTION,
41         }
42         return table.get(token, None)
43
44     def num_operands(self) -> Optional[int]:
45         table = {
46             Operation.INVERSION: 1,
47             Operation.CONJUNCTION: 2,
48             Operation.DISJUNCTION: 2,
49         }
50         return table.get(self, None)
51
52
53 class Corpus(object):
54     """A collection of searchable documents.
55
56     >>> c = Corpus()
57     >>> c.add_doc(Document(
58     ...                    docid=1,
59     ...                    tags=set(['urgent', 'important']),
60     ...                    properties=[
61     ...                                ('author', 'Scott'),
62     ...                                ('subject', 'your anniversary')
63     ...                    ],
64     ...                    reference=None,
65     ...                   )
66     ...          )
67     >>> c.add_doc(Document(
68     ...                    docid=2,
69     ...                    tags=set(['important']),
70     ...                    properties=[
71     ...                                ('author', 'Joe'),
72     ...                                ('subject', 'your performance at work')
73     ...                    ],
74     ...                    reference=None,
75     ...                   )
76     ...          )
77     >>> c.add_doc(Document(
78     ...                    docid=3,
79     ...                    tags=set(['urgent']),
80     ...                    properties=[
81     ...                                ('author', 'Scott'),
82     ...                                ('subject', 'car turning in front of you')
83     ...                    ],
84     ...                    reference=None,
85     ...                   )
86     ...          )
87     >>> c.query('author:Scott and important')
88     {1}
89     """
90
91     def __init__(self) -> None:
92         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
93         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
94         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
95         self.documents_by_docid: Dict[str, Document] = {}
96
97     def add_doc(self, doc: Document) -> None:
98         """Add a new Document to the Corpus.  Each Document must have a
99         distinct docid that will serve as its primary identifier.  If
100         the same Document is added multiple times, only the most
101         recent addition is indexed.  If two distinct documents with
102         the same docid are added, the latter klobbers the former in the
103         indexes.
104
105         Each Document may have an optional set of tags which can be
106         used later in expressions to the query method.
107
108         Each Document may have an optional list of key->value tuples
109         which can be used later in expressions to the query method.
110
111         Document includes a user-defined "reference" field which is
112         never interpreted by this module.  This is meant to allow easy
113         mapping between Documents in this corpus and external objects
114         they may represent.
115         """
116
117         if doc.docid in self.documents_by_docid:
118             # Handle collisions; assume that we are re-indexing the
119             # same document so remove it from the indexes before
120             # adding it back again.
121             colliding_doc = self.documents_by_docid[doc.docid]
122             assert colliding_doc.docid == doc.docid
123             del self.documents_by_docid[doc.docid]
124             for tag in colliding_doc.tags:
125                 self.docids_by_tag[tag].remove(doc.docid)
126             for key, value in colliding_doc.properties:
127                 self.docids_by_property[(key, value)].remove(doc.docid)
128                 self.docids_with_property[key].remove(doc.docid)
129
130         # Index the new Document
131         assert doc.docid not in self.documents_by_docid
132         self.documents_by_docid[doc.docid] = doc
133         for tag in doc.tags:
134             self.docids_by_tag[tag].add(doc.docid)
135         for key, value in doc.properties:
136             self.docids_by_property[(key, value)].add(doc.docid)
137             self.docids_with_property[key].add(doc.docid)
138
139     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
140         """Return the set of docids that have a particular tag."""
141
142         return self.docids_by_tag[tag]
143
144     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
145         """Return the set of docids with a tag that contains a str"""
146
147         ret = set()
148         for search_tag in self.docids_by_tag:
149             if tag in search_tag:
150                 for docid in self.docids_by_tag[search_tag]:
151                     ret.add(docid)
152         return ret
153
154     def get_docids_with_property(self, key: str) -> Set[str]:
155         """Return the set of docids that have a particular property no matter
156         what that property's value.
157
158         """
159         return self.docids_with_property[key]
160
161     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
162         """Return the set of docids that have a particular property with a
163         particular value..
164
165         """
166         return self.docids_by_property[(key, value)]
167
168     def invert_docid_set(self, original: Set[str]) -> Set[str]:
169         """Invert a set of docids."""
170
171         return set([docid for docid in self.documents_by_docid.keys() if docid not in original])
172
173     def get_doc(self, docid: str) -> Optional[Document]:
174         """Given a docid, retrieve the previously added Document."""
175
176         return self.documents_by_docid.get(docid, None)
177
178     def query(self, query: str) -> Optional[Set[str]]:
179         """Query the corpus for documents that match a logical expression.
180         Returns a (potentially empty) set of docids for the matching
181         (previously added) documents or None on error.
182
183         e.g.
184
185         tag1 and tag2 and not tag3
186
187         (tag1 or tag2) and (tag3 or tag4)
188
189         (tag1 and key2:value2) or (tag2 and key1:value1)
190
191         key:*
192
193         tag1 and key:*
194         """
195
196         try:
197             root = self._parse_query(query)
198         except ParseError as e:
199             print(e.message, file=sys.stderr)
200             return None
201         return root.eval()
202
203     def _parse_query(self, query: str):
204         """Internal parse helper; prefer to use query instead."""
205
206         parens = set(["(", ")"])
207         and_or = set(["and", "or"])
208
209         def operator_precedence(token: str) -> Optional[int]:
210             table = {
211                 "(": 4,  # higher
212                 ")": 4,
213                 "not": 3,
214                 "and": 2,
215                 "or": 1,  # lower
216             }
217             return table.get(token, None)
218
219         def is_operator(token: str) -> bool:
220             return operator_precedence(token) is not None
221
222         def lex(query: str):
223             tokens = query.split()
224             for token in tokens:
225                 # Handle ( and ) operators stuck to the ends of tokens
226                 # that split() doesn't understand.
227                 if len(token) > 1:
228                     first = token[0]
229                     if first in parens:
230                         tail = token[1:]
231                         yield first
232                         token = tail
233                     last = token[-1]
234                     if last in parens:
235                         head = token[0:-1]
236                         yield head
237                         token = last
238                 yield token
239
240         def evaluate(corpus: Corpus, stack: List[str]):
241             node_stack: List[Node] = []
242             for token in stack:
243                 node = None
244                 if not is_operator(token):
245                     node = Node(corpus, Operation.QUERY, [token])
246                 else:
247                     args = []
248                     operation = Operation.from_token(token)
249                     operand_count = operation.num_operands()
250                     if len(node_stack) < operand_count:
251                         raise ParseError(f"Incorrect number of operations for {operation}")
252                     for _ in range(operation.num_operands()):
253                         args.append(node_stack.pop())
254                     node = Node(corpus, operation, args)
255                 node_stack.append(node)
256             return node_stack[0]
257
258         output_stack = []
259         operator_stack = []
260         for token in lex(query):
261             if not is_operator(token):
262                 output_stack.append(token)
263                 continue
264
265             # token is an operator...
266             if token == "(":
267                 operator_stack.append(token)
268             elif token == ")":
269                 ok = False
270                 while len(operator_stack) > 0:
271                     pop_operator = operator_stack.pop()
272                     if pop_operator != "(":
273                         output_stack.append(pop_operator)
274                     else:
275                         ok = True
276                         break
277                 if not ok:
278                     raise ParseError("Unbalanced parenthesis in query expression")
279
280             # and, or, not
281             else:
282                 my_precedence = operator_precedence(token)
283                 if my_precedence is None:
284                     raise ParseError(f"Unknown operator: {token}")
285                 while len(operator_stack) > 0:
286                     peek_operator = operator_stack[-1]
287                     if not is_operator(peek_operator) or peek_operator == "(":
288                         break
289                     peek_precedence = operator_precedence(peek_operator)
290                     if peek_precedence is None:
291                         raise ParseError("Internal error")
292                     if (
293                         (peek_precedence < my_precedence)
294                         or (peek_precedence == my_precedence)
295                         and (peek_operator not in and_or)
296                     ):
297                         break
298                     output_stack.append(operator_stack.pop())
299                 operator_stack.append(token)
300         while len(operator_stack) > 0:
301             token = operator_stack.pop()
302             if token in parens:
303                 raise ParseError("Unbalanced parenthesis in query expression")
304             output_stack.append(token)
305         return evaluate(self, output_stack)
306
307
308 class Node(object):
309     """A query AST node."""
310
311     def __init__(
312         self,
313         corpus: Corpus,
314         op: Operation,
315         operands: Sequence[Union[Node, str]],
316     ):
317         self.corpus = corpus
318         self.op = op
319         self.operands = operands
320
321     def eval(self) -> Set[str]:
322         """Evaluate this node."""
323
324         evaled_operands: List[Union[Set[str], str]] = []
325         for operand in self.operands:
326             if isinstance(operand, Node):
327                 evaled_operands.append(operand.eval())
328             elif isinstance(operand, str):
329                 evaled_operands.append(operand)
330             else:
331                 raise ParseError(f"Unexpected operand: {operand}")
332
333         retval = set()
334         if self.op is Operation.QUERY:
335             for tag in evaled_operands:
336                 if isinstance(tag, str):
337                     if ":" in tag:
338                         try:
339                             key, value = tag.split(":")
340                         except ValueError as v:
341                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
342                         if value == "*":
343                             r = self.corpus.get_docids_with_property(key)
344                         else:
345                             r = self.corpus.get_docids_by_property(key, value)
346                     else:
347                         r = self.corpus.get_docids_by_exact_tag(tag)
348                     retval.update(r)
349                 else:
350                     raise ParseError(f"Unexpected query {tag}")
351         elif self.op is Operation.DISJUNCTION:
352             if len(evaled_operands) != 2:
353                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
354             retval.update(evaled_operands[0])
355             retval.update(evaled_operands[1])
356         elif self.op is Operation.CONJUNCTION:
357             if len(evaled_operands) != 2:
358                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
359             retval.update(evaled_operands[0])
360             retval = retval.intersection(evaled_operands[1])
361         elif self.op is Operation.INVERSION:
362             if len(evaled_operands) != 1:
363                 raise ParseError("Operation.INVERSION (not) expects one operand.")
364             _ = evaled_operands[0]
365             if isinstance(_, set):
366                 retval.update(self.corpus.invert_docid_set(_))
367             else:
368                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
369         return retval
370
371
372 if __name__ == '__main__':
373     import doctest
374
375     doctest.testmod()