Preformatted box that doesn't wrap the contents.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module concerned with the creation of and searching of a
6 corpus of documents.  The corpus is held in memory for fast
7 searching.
8
9 """
10
11 from __future__ import annotations
12 import enum
13 import sys
14 from collections import defaultdict
15 from dataclasses import dataclass, field
16 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
17
18
19 class ParseError(Exception):
20     """An error encountered while parsing a logical search expression."""
21
22     def __init__(self, message: str):
23         super().__init__()
24         self.message = message
25
26
27 @dataclass
28 class Document:
29     """A class representing a searchable document."""
30
31     # A unique identifier for each document.
32     docid: str = ''
33
34     # A set of tag strings for this document.  May be empty.
35     tags: Set[str] = field(default_factory=set)
36
37     # A list of key->value strings for this document.  May be empty.
38     properties: List[Tuple[str, str]] = field(default_factory=list)
39
40     # An optional reference to something else; interpreted only by
41     # caller code, ignored here.
42     reference: Optional[Any] = None
43
44
45 class Operation(enum.Enum):
46     """A logical search query operation."""
47
48     QUERY = 1
49     CONJUNCTION = 2
50     DISJUNCTION = 3
51     INVERSION = 4
52
53     @staticmethod
54     def from_token(token: str):
55         table = {
56             "not": Operation.INVERSION,
57             "and": Operation.CONJUNCTION,
58             "or": Operation.DISJUNCTION,
59         }
60         return table.get(token, None)
61
62     def num_operands(self) -> Optional[int]:
63         table = {
64             Operation.INVERSION: 1,
65             Operation.CONJUNCTION: 2,
66             Operation.DISJUNCTION: 2,
67         }
68         return table.get(self, None)
69
70
71 class Corpus(object):
72     """A collection of searchable documents.
73
74     >>> c = Corpus()
75     >>> c.add_doc(Document(
76     ...                    docid=1,
77     ...                    tags=set(['urgent', 'important']),
78     ...                    properties=[
79     ...                                ('author', 'Scott'),
80     ...                                ('subject', 'your anniversary')
81     ...                    ],
82     ...                    reference=None,
83     ...                   )
84     ...          )
85     >>> c.add_doc(Document(
86     ...                    docid=2,
87     ...                    tags=set(['important']),
88     ...                    properties=[
89     ...                                ('author', 'Joe'),
90     ...                                ('subject', 'your performance at work')
91     ...                    ],
92     ...                    reference=None,
93     ...                   )
94     ...          )
95     >>> c.add_doc(Document(
96     ...                    docid=3,
97     ...                    tags=set(['urgent']),
98     ...                    properties=[
99     ...                                ('author', 'Scott'),
100     ...                                ('subject', 'car turning in front of you')
101     ...                    ],
102     ...                    reference=None,
103     ...                   )
104     ...          )
105     >>> c.query('author:Scott and important')
106     {1}
107     >>> c.query('*')
108     {1, 2, 3}
109     """
110
111     def __init__(self) -> None:
112         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
113         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
114         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
115         self.documents_by_docid: Dict[str, Document] = {}
116
117     def add_doc(self, doc: Document) -> None:
118         """Add a new Document to the Corpus.  Each Document must have a
119         distinct docid that will serve as its primary identifier.  If
120         the same Document is added multiple times, only the most
121         recent addition is indexed.  If two distinct documents with
122         the same docid are added, the latter klobbers the former in the
123         indexes.
124
125         Each Document may have an optional set of tags which can be
126         used later in expressions to the query method.
127
128         Each Document may have an optional list of key->value tuples
129         which can be used later in expressions to the query method.
130
131         Document includes a user-defined "reference" field which is
132         never interpreted by this module.  This is meant to allow easy
133         mapping between Documents in this corpus and external objects
134         they may represent.
135         """
136
137         if doc.docid in self.documents_by_docid:
138             # Handle collisions; assume that we are re-indexing the
139             # same document so remove it from the indexes before
140             # adding it back again.
141             colliding_doc = self.documents_by_docid[doc.docid]
142             assert colliding_doc.docid == doc.docid
143             del self.documents_by_docid[doc.docid]
144             for tag in colliding_doc.tags:
145                 self.docids_by_tag[tag].remove(doc.docid)
146             for key, value in colliding_doc.properties:
147                 self.docids_by_property[(key, value)].remove(doc.docid)
148                 self.docids_with_property[key].remove(doc.docid)
149
150         # Index the new Document
151         assert doc.docid not in self.documents_by_docid
152         self.documents_by_docid[doc.docid] = doc
153         for tag in doc.tags:
154             self.docids_by_tag[tag].add(doc.docid)
155         for key, value in doc.properties:
156             self.docids_by_property[(key, value)].add(doc.docid)
157             self.docids_with_property[key].add(doc.docid)
158
159     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
160         """Return the set of docids that have a particular tag."""
161
162         return self.docids_by_tag[tag]
163
164     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
165         """Return the set of docids with a tag that contains a str"""
166
167         ret = set()
168         for search_tag in self.docids_by_tag:
169             if tag in search_tag:
170                 for docid in self.docids_by_tag[search_tag]:
171                     ret.add(docid)
172         return ret
173
174     def get_docids_with_property(self, key: str) -> Set[str]:
175         """Return the set of docids that have a particular property no matter
176         what that property's value.
177
178         """
179         return self.docids_with_property[key]
180
181     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
182         """Return the set of docids that have a particular property with a
183         particular value..
184
185         """
186         return self.docids_by_property[(key, value)]
187
188     def invert_docid_set(self, original: Set[str]) -> Set[str]:
189         """Invert a set of docids."""
190
191         return {docid for docid in self.documents_by_docid if docid not in original}
192
193     def get_doc(self, docid: str) -> Optional[Document]:
194         """Given a docid, retrieve the previously added Document."""
195
196         return self.documents_by_docid.get(docid, None)
197
198     def query(self, query: str) -> Optional[Set[str]]:
199         """Query the corpus for documents that match a logical expression.
200         Returns a (potentially empty) set of docids for the matching
201         (previously added) documents or None on error.
202
203         e.g.
204
205         tag1 and tag2 and not tag3
206
207         (tag1 or tag2) and (tag3 or tag4)
208
209         (tag1 and key2:value2) or (tag2 and key1:value1)
210
211         key:*
212
213         tag1 and key:*
214         """
215
216         if query == '*':
217             return set(self.documents_by_docid.keys())
218         try:
219             root = self._parse_query(query)
220         except ParseError as e:
221             print(e.message, file=sys.stderr)
222             return None
223         return root.eval()
224
225     def _parse_query(self, query: str):
226         """Internal parse helper; prefer to use query instead."""
227
228         parens = set(["(", ")"])
229         and_or = set(["and", "or"])
230
231         def operator_precedence(token: str) -> Optional[int]:
232             table = {
233                 "(": 4,  # higher
234                 ")": 4,
235                 "not": 3,
236                 "and": 2,
237                 "or": 1,  # lower
238             }
239             return table.get(token, None)
240
241         def is_operator(token: str) -> bool:
242             return operator_precedence(token) is not None
243
244         def lex(query: str):
245             tokens = query.split()
246             for token in tokens:
247                 # Handle ( and ) operators stuck to the ends of tokens
248                 # that split() doesn't understand.
249                 if len(token) > 1:
250                     first = token[0]
251                     if first in parens:
252                         tail = token[1:]
253                         yield first
254                         token = tail
255                     last = token[-1]
256                     if last in parens:
257                         head = token[0:-1]
258                         yield head
259                         token = last
260                 yield token
261
262         def evaluate(corpus: Corpus, stack: List[str]):
263             node_stack: List[Node] = []
264             for token in stack:
265                 node = None
266                 if not is_operator(token):
267                     node = Node(corpus, Operation.QUERY, [token])
268                 else:
269                     args = []
270                     operation = Operation.from_token(token)
271                     operand_count = operation.num_operands()
272                     if len(node_stack) < operand_count:
273                         raise ParseError(f"Incorrect number of operations for {operation}")
274                     for _ in range(operation.num_operands()):
275                         args.append(node_stack.pop())
276                     node = Node(corpus, operation, args)
277                 node_stack.append(node)
278             return node_stack[0]
279
280         output_stack = []
281         operator_stack = []
282         for token in lex(query):
283             if not is_operator(token):
284                 output_stack.append(token)
285                 continue
286
287             # token is an operator...
288             if token == "(":
289                 operator_stack.append(token)
290             elif token == ")":
291                 ok = False
292                 while len(operator_stack) > 0:
293                     pop_operator = operator_stack.pop()
294                     if pop_operator != "(":
295                         output_stack.append(pop_operator)
296                     else:
297                         ok = True
298                         break
299                 if not ok:
300                     raise ParseError("Unbalanced parenthesis in query expression")
301
302             # and, or, not
303             else:
304                 my_precedence = operator_precedence(token)
305                 if my_precedence is None:
306                     raise ParseError(f"Unknown operator: {token}")
307                 while len(operator_stack) > 0:
308                     peek_operator = operator_stack[-1]
309                     if not is_operator(peek_operator) or peek_operator == "(":
310                         break
311                     peek_precedence = operator_precedence(peek_operator)
312                     if peek_precedence is None:
313                         raise ParseError("Internal error")
314                     if (
315                         (peek_precedence < my_precedence)
316                         or (peek_precedence == my_precedence)
317                         and (peek_operator not in and_or)
318                     ):
319                         break
320                     output_stack.append(operator_stack.pop())
321                 operator_stack.append(token)
322         while len(operator_stack) > 0:
323             token = operator_stack.pop()
324             if token in parens:
325                 raise ParseError("Unbalanced parenthesis in query expression")
326             output_stack.append(token)
327         return evaluate(self, output_stack)
328
329
330 class Node(object):
331     """A query AST node."""
332
333     def __init__(
334         self,
335         corpus: Corpus,
336         op: Operation,
337         operands: Sequence[Union[Node, str]],
338     ):
339         self.corpus = corpus
340         self.op = op
341         self.operands = operands
342
343     def eval(self) -> Set[str]:
344         """Evaluate this node."""
345
346         evaled_operands: List[Union[Set[str], str]] = []
347         for operand in self.operands:
348             if isinstance(operand, Node):
349                 evaled_operands.append(operand.eval())
350             elif isinstance(operand, str):
351                 evaled_operands.append(operand)
352             else:
353                 raise ParseError(f"Unexpected operand: {operand}")
354
355         retval = set()
356         if self.op is Operation.QUERY:
357             for tag in evaled_operands:
358                 if isinstance(tag, str):
359                     if ":" in tag:
360                         try:
361                             key, value = tag.split(":")
362                         except ValueError as v:
363                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
364                         if value == "*":
365                             r = self.corpus.get_docids_with_property(key)
366                         else:
367                             r = self.corpus.get_docids_by_property(key, value)
368                     else:
369                         r = self.corpus.get_docids_by_exact_tag(tag)
370                     retval.update(r)
371                 else:
372                     raise ParseError(f"Unexpected query {tag}")
373         elif self.op is Operation.DISJUNCTION:
374             if len(evaled_operands) != 2:
375                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
376             retval.update(evaled_operands[0])
377             retval.update(evaled_operands[1])
378         elif self.op is Operation.CONJUNCTION:
379             if len(evaled_operands) != 2:
380                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
381             retval.update(evaled_operands[0])
382             retval = retval.intersection(evaled_operands[1])
383         elif self.op is Operation.INVERSION:
384             if len(evaled_operands) != 1:
385                 raise ParseError("Operation.INVERSION (not) expects one operand.")
386             _ = evaled_operands[0]
387             if isinstance(_, set):
388                 retval.update(self.corpus.invert_docid_set(_))
389             else:
390                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
391         return retval
392
393
394 if __name__ == '__main__':
395     import doctest
396
397     doctest.testmod()