More cleanup, yey!
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 """This is a module concerned with the creation of and searching of a
4 corpus of documents.  The corpus is held in memory for fast
5 searching."""
6
7 from __future__ import annotations
8
9 import enum
10 import sys
11 from collections import defaultdict
12 from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union
13
14
15 class ParseError(Exception):
16     """An error encountered while parsing a logical search expression."""
17
18     def __init__(self, message: str):
19         super().__init__()
20         self.message = message
21
22
23 class Document(NamedTuple):
24     """A tuple representing a searchable document."""
25
26     docid: str  # a unique idenfier for the document
27     tags: Set[str]  # an optional set of tags
28     properties: List[Tuple[str, str]]  # an optional set of key->value properties
29     reference: Any  # an optional reference to something else
30
31
32 class Operation(enum.Enum):
33     """A logical search query operation."""
34
35     QUERY = 1
36     CONJUNCTION = 2
37     DISJUNCTION = 3
38     INVERSION = 4
39
40     @staticmethod
41     def from_token(token: str):
42         table = {
43             "not": Operation.INVERSION,
44             "and": Operation.CONJUNCTION,
45             "or": Operation.DISJUNCTION,
46         }
47         return table.get(token, None)
48
49     def num_operands(self) -> Optional[int]:
50         table = {
51             Operation.INVERSION: 1,
52             Operation.CONJUNCTION: 2,
53             Operation.DISJUNCTION: 2,
54         }
55         return table.get(self, None)
56
57
58 class Corpus(object):
59     """A collection of searchable documents.
60
61     >>> c = Corpus()
62     >>> c.add_doc(Document(
63     ...                    docid=1,
64     ...                    tags=set(['urgent', 'important']),
65     ...                    properties=[
66     ...                                ('author', 'Scott'),
67     ...                                ('subject', 'your anniversary')
68     ...                    ],
69     ...                    reference=None,
70     ...                   )
71     ...          )
72     >>> c.add_doc(Document(
73     ...                    docid=2,
74     ...                    tags=set(['important']),
75     ...                    properties=[
76     ...                                ('author', 'Joe'),
77     ...                                ('subject', 'your performance at work')
78     ...                    ],
79     ...                    reference=None,
80     ...                   )
81     ...          )
82     >>> c.add_doc(Document(
83     ...                    docid=3,
84     ...                    tags=set(['urgent']),
85     ...                    properties=[
86     ...                                ('author', 'Scott'),
87     ...                                ('subject', 'car turning in front of you')
88     ...                    ],
89     ...                    reference=None,
90     ...                   )
91     ...          )
92     >>> c.query('author:Scott and important')
93     {1}
94     """
95
96     def __init__(self) -> None:
97         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
98         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
99         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
100         self.documents_by_docid: Dict[str, Document] = {}
101
102     def add_doc(self, doc: Document) -> None:
103         """Add a new Document to the Corpus.  Each Document must have a
104         distinct docid that will serve as its primary identifier.  If
105         the same Document is added multiple times, only the most
106         recent addition is indexed.  If two distinct documents with
107         the same docid are added, the latter klobbers the former in the
108         indexes.
109
110         Each Document may have an optional set of tags which can be
111         used later in expressions to the query method.
112
113         Each Document may have an optional list of key->value tuples
114         which can be used later in expressions to the query method.
115
116         Document includes a user-defined "reference" field which is
117         never interpreted by this module.  This is meant to allow easy
118         mapping between Documents in this corpus and external objects
119         they may represent.
120         """
121
122         if doc.docid in self.documents_by_docid:
123             # Handle collisions; assume that we are re-indexing the
124             # same document so remove it from the indexes before
125             # adding it back again.
126             colliding_doc = self.documents_by_docid[doc.docid]
127             assert colliding_doc.docid == doc.docid
128             del self.documents_by_docid[doc.docid]
129             for tag in colliding_doc.tags:
130                 self.docids_by_tag[tag].remove(doc.docid)
131             for key, value in colliding_doc.properties:
132                 self.docids_by_property[(key, value)].remove(doc.docid)
133                 self.docids_with_property[key].remove(doc.docid)
134
135         # Index the new Document
136         assert doc.docid not in self.documents_by_docid
137         self.documents_by_docid[doc.docid] = doc
138         for tag in doc.tags:
139             self.docids_by_tag[tag].add(doc.docid)
140         for key, value in doc.properties:
141             self.docids_by_property[(key, value)].add(doc.docid)
142             self.docids_with_property[key].add(doc.docid)
143
144     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
145         """Return the set of docids that have a particular tag."""
146
147         return self.docids_by_tag[tag]
148
149     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
150         """Return the set of docids with a tag that contains a str"""
151
152         ret = set()
153         for search_tag in self.docids_by_tag:
154             if tag in search_tag:
155                 for docid in self.docids_by_tag[search_tag]:
156                     ret.add(docid)
157         return ret
158
159     def get_docids_with_property(self, key: str) -> Set[str]:
160         """Return the set of docids that have a particular property no matter
161         what that property's value.
162
163         """
164         return self.docids_with_property[key]
165
166     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
167         """Return the set of docids that have a particular property with a
168         particular value..
169
170         """
171         return self.docids_by_property[(key, value)]
172
173     def invert_docid_set(self, original: Set[str]) -> Set[str]:
174         """Invert a set of docids."""
175
176         return set([docid for docid in self.documents_by_docid.keys() if docid not in original])
177
178     def get_doc(self, docid: str) -> Optional[Document]:
179         """Given a docid, retrieve the previously added Document."""
180
181         return self.documents_by_docid.get(docid, None)
182
183     def query(self, query: str) -> Optional[Set[str]]:
184         """Query the corpus for documents that match a logical expression.
185         Returns a (potentially empty) set of docids for the matching
186         (previously added) documents or None on error.
187
188         e.g.
189
190         tag1 and tag2 and not tag3
191
192         (tag1 or tag2) and (tag3 or tag4)
193
194         (tag1 and key2:value2) or (tag2 and key1:value1)
195
196         key:*
197
198         tag1 and key:*
199         """
200
201         try:
202             root = self._parse_query(query)
203         except ParseError as e:
204             print(e.message, file=sys.stderr)
205             return None
206         return root.eval()
207
208     def _parse_query(self, query: str):
209         """Internal parse helper; prefer to use query instead."""
210
211         parens = set(["(", ")"])
212         and_or = set(["and", "or"])
213
214         def operator_precedence(token: str) -> Optional[int]:
215             table = {
216                 "(": 4,  # higher
217                 ")": 4,
218                 "not": 3,
219                 "and": 2,
220                 "or": 1,  # lower
221             }
222             return table.get(token, None)
223
224         def is_operator(token: str) -> bool:
225             return operator_precedence(token) is not None
226
227         def lex(query: str):
228             tokens = query.split()
229             for token in tokens:
230                 # Handle ( and ) operators stuck to the ends of tokens
231                 # that split() doesn't understand.
232                 if len(token) > 1:
233                     first = token[0]
234                     if first in parens:
235                         tail = token[1:]
236                         yield first
237                         token = tail
238                     last = token[-1]
239                     if last in parens:
240                         head = token[0:-1]
241                         yield head
242                         token = last
243                 yield token
244
245         def evaluate(corpus: Corpus, stack: List[str]):
246             node_stack: List[Node] = []
247             for token in stack:
248                 node = None
249                 if not is_operator(token):
250                     node = Node(corpus, Operation.QUERY, [token])
251                 else:
252                     args = []
253                     operation = Operation.from_token(token)
254                     operand_count = operation.num_operands()
255                     if len(node_stack) < operand_count:
256                         raise ParseError(f"Incorrect number of operations for {operation}")
257                     for _ in range(operation.num_operands()):
258                         args.append(node_stack.pop())
259                     node = Node(corpus, operation, args)
260                 node_stack.append(node)
261             return node_stack[0]
262
263         output_stack = []
264         operator_stack = []
265         for token in lex(query):
266             if not is_operator(token):
267                 output_stack.append(token)
268                 continue
269
270             # token is an operator...
271             if token == "(":
272                 operator_stack.append(token)
273             elif token == ")":
274                 ok = False
275                 while len(operator_stack) > 0:
276                     pop_operator = operator_stack.pop()
277                     if pop_operator != "(":
278                         output_stack.append(pop_operator)
279                     else:
280                         ok = True
281                         break
282                 if not ok:
283                     raise ParseError("Unbalanced parenthesis in query expression")
284
285             # and, or, not
286             else:
287                 my_precedence = operator_precedence(token)
288                 if my_precedence is None:
289                     raise ParseError(f"Unknown operator: {token}")
290                 while len(operator_stack) > 0:
291                     peek_operator = operator_stack[-1]
292                     if not is_operator(peek_operator) or peek_operator == "(":
293                         break
294                     peek_precedence = operator_precedence(peek_operator)
295                     if peek_precedence is None:
296                         raise ParseError("Internal error")
297                     if (
298                         (peek_precedence < my_precedence)
299                         or (peek_precedence == my_precedence)
300                         and (peek_operator not in and_or)
301                     ):
302                         break
303                     output_stack.append(operator_stack.pop())
304                 operator_stack.append(token)
305         while len(operator_stack) > 0:
306             token = operator_stack.pop()
307             if token in parens:
308                 raise ParseError("Unbalanced parenthesis in query expression")
309             output_stack.append(token)
310         return evaluate(self, output_stack)
311
312
313 class Node(object):
314     """A query AST node."""
315
316     def __init__(
317         self,
318         corpus: Corpus,
319         op: Operation,
320         operands: Sequence[Union[Node, str]],
321     ):
322         self.corpus = corpus
323         self.op = op
324         self.operands = operands
325
326     def eval(self) -> Set[str]:
327         """Evaluate this node."""
328
329         evaled_operands: List[Union[Set[str], str]] = []
330         for operand in self.operands:
331             if isinstance(operand, Node):
332                 evaled_operands.append(operand.eval())
333             elif isinstance(operand, str):
334                 evaled_operands.append(operand)
335             else:
336                 raise ParseError(f"Unexpected operand: {operand}")
337
338         retval = set()
339         if self.op is Operation.QUERY:
340             for tag in evaled_operands:
341                 if isinstance(tag, str):
342                     if ":" in tag:
343                         try:
344                             key, value = tag.split(":")
345                         except ValueError as v:
346                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
347                         if value == "*":
348                             r = self.corpus.get_docids_with_property(key)
349                         else:
350                             r = self.corpus.get_docids_by_property(key, value)
351                     else:
352                         r = self.corpus.get_docids_by_exact_tag(tag)
353                     retval.update(r)
354                 else:
355                     raise ParseError(f"Unexpected query {tag}")
356         elif self.op is Operation.DISJUNCTION:
357             if len(evaled_operands) != 2:
358                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
359             retval.update(evaled_operands[0])
360             retval.update(evaled_operands[1])
361         elif self.op is Operation.CONJUNCTION:
362             if len(evaled_operands) != 2:
363                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
364             retval.update(evaled_operands[0])
365             retval = retval.intersection(evaled_operands[1])
366         elif self.op is Operation.INVERSION:
367             if len(evaled_operands) != 1:
368                 raise ParseError("Operation.INVERSION (not) expects one operand.")
369             _ = evaled_operands[0]
370             if isinstance(_, set):
371                 retval.update(self.corpus.invert_docid_set(_))
372             else:
373                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
374         return retval
375
376
377 if __name__ == '__main__':
378     import doctest
379
380     doctest.testmod()