b6d7479879010d6ea40ef813d03e84574ead7e55
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module concerned with the creation of and searching of a
6 corpus of documents.  The corpus is held in memory for fast
7 searching.
8
9 """
10
11 from __future__ import annotations
12 import enum
13 import sys
14 from collections import defaultdict
15 from dataclasses import dataclass, field
16 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
17
18
19 class ParseError(Exception):
20     """An error encountered while parsing a logical search expression."""
21
22     def __init__(self, message: str):
23         super().__init__()
24         self.message = message
25
26
27 @dataclass
28 class Document:
29     """A class representing a searchable document."""
30
31     # A unique identifier for each document.
32     docid: str = ''
33
34     # A set of tag strings for this document.  May be empty.
35     tags: Set[str] = field(default_factory=set)
36
37     # A list of key->value strings for this document.  May be empty.
38     properties: List[Tuple[str, str]] = field(default_factory=list)
39
40     # An optional reference to something else; interpreted only by
41     # caller code, ignored here.
42     reference: Optional[Any] = None
43
44
45 class Operation(enum.Enum):
46     """A logical search query operation."""
47
48     QUERY = 1
49     CONJUNCTION = 2
50     DISJUNCTION = 3
51     INVERSION = 4
52
53     @staticmethod
54     def from_token(token: str):
55         table = {
56             "not": Operation.INVERSION,
57             "and": Operation.CONJUNCTION,
58             "or": Operation.DISJUNCTION,
59         }
60         return table.get(token, None)
61
62     def num_operands(self) -> Optional[int]:
63         table = {
64             Operation.INVERSION: 1,
65             Operation.CONJUNCTION: 2,
66             Operation.DISJUNCTION: 2,
67         }
68         return table.get(self, None)
69
70
71 class Corpus(object):
72     """A collection of searchable documents.
73
74     >>> c = Corpus()
75     >>> c.add_doc(Document(
76     ...                    docid=1,
77     ...                    tags=set(['urgent', 'important']),
78     ...                    properties=[
79     ...                                ('author', 'Scott'),
80     ...                                ('subject', 'your anniversary')
81     ...                    ],
82     ...                    reference=None,
83     ...                   )
84     ...          )
85     >>> c.add_doc(Document(
86     ...                    docid=2,
87     ...                    tags=set(['important']),
88     ...                    properties=[
89     ...                                ('author', 'Joe'),
90     ...                                ('subject', 'your performance at work')
91     ...                    ],
92     ...                    reference=None,
93     ...                   )
94     ...          )
95     >>> c.add_doc(Document(
96     ...                    docid=3,
97     ...                    tags=set(['urgent']),
98     ...                    properties=[
99     ...                                ('author', 'Scott'),
100     ...                                ('subject', 'car turning in front of you')
101     ...                    ],
102     ...                    reference=None,
103     ...                   )
104     ...          )
105     >>> c.query('author:Scott and important')
106     {1}
107     """
108
109     def __init__(self) -> None:
110         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
111         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
112         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
113         self.documents_by_docid: Dict[str, Document] = {}
114
115     def add_doc(self, doc: Document) -> None:
116         """Add a new Document to the Corpus.  Each Document must have a
117         distinct docid that will serve as its primary identifier.  If
118         the same Document is added multiple times, only the most
119         recent addition is indexed.  If two distinct documents with
120         the same docid are added, the latter klobbers the former in the
121         indexes.
122
123         Each Document may have an optional set of tags which can be
124         used later in expressions to the query method.
125
126         Each Document may have an optional list of key->value tuples
127         which can be used later in expressions to the query method.
128
129         Document includes a user-defined "reference" field which is
130         never interpreted by this module.  This is meant to allow easy
131         mapping between Documents in this corpus and external objects
132         they may represent.
133         """
134
135         if doc.docid in self.documents_by_docid:
136             # Handle collisions; assume that we are re-indexing the
137             # same document so remove it from the indexes before
138             # adding it back again.
139             colliding_doc = self.documents_by_docid[doc.docid]
140             assert colliding_doc.docid == doc.docid
141             del self.documents_by_docid[doc.docid]
142             for tag in colliding_doc.tags:
143                 self.docids_by_tag[tag].remove(doc.docid)
144             for key, value in colliding_doc.properties:
145                 self.docids_by_property[(key, value)].remove(doc.docid)
146                 self.docids_with_property[key].remove(doc.docid)
147
148         # Index the new Document
149         assert doc.docid not in self.documents_by_docid
150         self.documents_by_docid[doc.docid] = doc
151         for tag in doc.tags:
152             self.docids_by_tag[tag].add(doc.docid)
153         for key, value in doc.properties:
154             self.docids_by_property[(key, value)].add(doc.docid)
155             self.docids_with_property[key].add(doc.docid)
156
157     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
158         """Return the set of docids that have a particular tag."""
159
160         return self.docids_by_tag[tag]
161
162     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
163         """Return the set of docids with a tag that contains a str"""
164
165         ret = set()
166         for search_tag in self.docids_by_tag:
167             if tag in search_tag:
168                 for docid in self.docids_by_tag[search_tag]:
169                     ret.add(docid)
170         return ret
171
172     def get_docids_with_property(self, key: str) -> Set[str]:
173         """Return the set of docids that have a particular property no matter
174         what that property's value.
175
176         """
177         return self.docids_with_property[key]
178
179     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
180         """Return the set of docids that have a particular property with a
181         particular value..
182
183         """
184         return self.docids_by_property[(key, value)]
185
186     def invert_docid_set(self, original: Set[str]) -> Set[str]:
187         """Invert a set of docids."""
188
189         return {docid for docid in self.documents_by_docid if docid not in original}
190
191     def get_doc(self, docid: str) -> Optional[Document]:
192         """Given a docid, retrieve the previously added Document."""
193
194         return self.documents_by_docid.get(docid, None)
195
196     def query(self, query: str) -> Optional[Set[str]]:
197         """Query the corpus for documents that match a logical expression.
198         Returns a (potentially empty) set of docids for the matching
199         (previously added) documents or None on error.
200
201         e.g.
202
203         tag1 and tag2 and not tag3
204
205         (tag1 or tag2) and (tag3 or tag4)
206
207         (tag1 and key2:value2) or (tag2 and key1:value1)
208
209         key:*
210
211         tag1 and key:*
212         """
213
214         try:
215             root = self._parse_query(query)
216         except ParseError as e:
217             print(e.message, file=sys.stderr)
218             return None
219         return root.eval()
220
221     def _parse_query(self, query: str):
222         """Internal parse helper; prefer to use query instead."""
223
224         parens = set(["(", ")"])
225         and_or = set(["and", "or"])
226
227         def operator_precedence(token: str) -> Optional[int]:
228             table = {
229                 "(": 4,  # higher
230                 ")": 4,
231                 "not": 3,
232                 "and": 2,
233                 "or": 1,  # lower
234             }
235             return table.get(token, None)
236
237         def is_operator(token: str) -> bool:
238             return operator_precedence(token) is not None
239
240         def lex(query: str):
241             tokens = query.split()
242             for token in tokens:
243                 # Handle ( and ) operators stuck to the ends of tokens
244                 # that split() doesn't understand.
245                 if len(token) > 1:
246                     first = token[0]
247                     if first in parens:
248                         tail = token[1:]
249                         yield first
250                         token = tail
251                     last = token[-1]
252                     if last in parens:
253                         head = token[0:-1]
254                         yield head
255                         token = last
256                 yield token
257
258         def evaluate(corpus: Corpus, stack: List[str]):
259             node_stack: List[Node] = []
260             for token in stack:
261                 node = None
262                 if not is_operator(token):
263                     node = Node(corpus, Operation.QUERY, [token])
264                 else:
265                     args = []
266                     operation = Operation.from_token(token)
267                     operand_count = operation.num_operands()
268                     if len(node_stack) < operand_count:
269                         raise ParseError(f"Incorrect number of operations for {operation}")
270                     for _ in range(operation.num_operands()):
271                         args.append(node_stack.pop())
272                     node = Node(corpus, operation, args)
273                 node_stack.append(node)
274             return node_stack[0]
275
276         output_stack = []
277         operator_stack = []
278         for token in lex(query):
279             if not is_operator(token):
280                 output_stack.append(token)
281                 continue
282
283             # token is an operator...
284             if token == "(":
285                 operator_stack.append(token)
286             elif token == ")":
287                 ok = False
288                 while len(operator_stack) > 0:
289                     pop_operator = operator_stack.pop()
290                     if pop_operator != "(":
291                         output_stack.append(pop_operator)
292                     else:
293                         ok = True
294                         break
295                 if not ok:
296                     raise ParseError("Unbalanced parenthesis in query expression")
297
298             # and, or, not
299             else:
300                 my_precedence = operator_precedence(token)
301                 if my_precedence is None:
302                     raise ParseError(f"Unknown operator: {token}")
303                 while len(operator_stack) > 0:
304                     peek_operator = operator_stack[-1]
305                     if not is_operator(peek_operator) or peek_operator == "(":
306                         break
307                     peek_precedence = operator_precedence(peek_operator)
308                     if peek_precedence is None:
309                         raise ParseError("Internal error")
310                     if (
311                         (peek_precedence < my_precedence)
312                         or (peek_precedence == my_precedence)
313                         and (peek_operator not in and_or)
314                     ):
315                         break
316                     output_stack.append(operator_stack.pop())
317                 operator_stack.append(token)
318         while len(operator_stack) > 0:
319             token = operator_stack.pop()
320             if token in parens:
321                 raise ParseError("Unbalanced parenthesis in query expression")
322             output_stack.append(token)
323         return evaluate(self, output_stack)
324
325
326 class Node(object):
327     """A query AST node."""
328
329     def __init__(
330         self,
331         corpus: Corpus,
332         op: Operation,
333         operands: Sequence[Union[Node, str]],
334     ):
335         self.corpus = corpus
336         self.op = op
337         self.operands = operands
338
339     def eval(self) -> Set[str]:
340         """Evaluate this node."""
341
342         evaled_operands: List[Union[Set[str], str]] = []
343         for operand in self.operands:
344             if isinstance(operand, Node):
345                 evaled_operands.append(operand.eval())
346             elif isinstance(operand, str):
347                 evaled_operands.append(operand)
348             else:
349                 raise ParseError(f"Unexpected operand: {operand}")
350
351         retval = set()
352         if self.op is Operation.QUERY:
353             for tag in evaled_operands:
354                 if isinstance(tag, str):
355                     if ":" in tag:
356                         try:
357                             key, value = tag.split(":")
358                         except ValueError as v:
359                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
360                         if value == "*":
361                             r = self.corpus.get_docids_with_property(key)
362                         else:
363                             r = self.corpus.get_docids_by_property(key, value)
364                     else:
365                         r = self.corpus.get_docids_by_exact_tag(tag)
366                     retval.update(r)
367                 else:
368                     raise ParseError(f"Unexpected query {tag}")
369         elif self.op is Operation.DISJUNCTION:
370             if len(evaled_operands) != 2:
371                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
372             retval.update(evaled_operands[0])
373             retval.update(evaled_operands[1])
374         elif self.op is Operation.CONJUNCTION:
375             if len(evaled_operands) != 2:
376                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
377             retval.update(evaled_operands[0])
378             retval = retval.intersection(evaled_operands[1])
379         elif self.op is Operation.INVERSION:
380             if len(evaled_operands) != 1:
381                 raise ParseError("Operation.INVERSION (not) expects one operand.")
382             _ = evaled_operands[0]
383             if isinstance(_, set):
384                 retval.update(self.corpus.invert_docid_set(_))
385             else:
386                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
387         return retval
388
389
390 if __name__ == '__main__':
391     import doctest
392
393     doctest.testmod()