Fix an edge condition around the Nth weekday of the month when
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module concerned with the creation of and searching of a
6 corpus of documents.  The corpus is held in memory for fast
7 searching.
8
9 """
10
11 from __future__ import annotations
12 import enum
13 import sys
14 from collections import defaultdict
15 from dataclasses import dataclass, field
16 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
17
18
19 class ParseError(Exception):
20     """An error encountered while parsing a logical search expression."""
21
22     def __init__(self, message: str):
23         super().__init__()
24         self.message = message
25
26
27 @dataclass
28 class Document:
29     """A class representing a searchable document."""
30
31     # A unique identifier for each document.
32     docid: str = ''
33
34     # A set of tag strings for this document.  May be empty.
35     tags: Set[str] = field(default_factory=set)
36
37     # A list of key->value strings for this document.  May be empty.
38     properties: List[Tuple[str, str]] = field(default_factory=list)
39
40     # An optional reference to something else; interpreted only by
41     # caller code, ignored here.
42     reference: Optional[Any] = None
43
44
45 class Operation(enum.Enum):
46     """A logical search query operation."""
47
48     QUERY = 1
49     CONJUNCTION = 2
50     DISJUNCTION = 3
51     INVERSION = 4
52
53     @staticmethod
54     def from_token(token: str):
55         table = {
56             "not": Operation.INVERSION,
57             "and": Operation.CONJUNCTION,
58             "or": Operation.DISJUNCTION,
59         }
60         return table.get(token, None)
61
62     def num_operands(self) -> Optional[int]:
63         table = {
64             Operation.INVERSION: 1,
65             Operation.CONJUNCTION: 2,
66             Operation.DISJUNCTION: 2,
67         }
68         return table.get(self, None)
69
70
71 class Corpus(object):
72     """A collection of searchable documents.
73
74     >>> c = Corpus()
75     >>> c.add_doc(Document(
76     ...                    docid=1,
77     ...                    tags=set(['urgent', 'important']),
78     ...                    properties=[
79     ...                                ('author', 'Scott'),
80     ...                                ('subject', 'your anniversary')
81     ...                    ],
82     ...                    reference=None,
83     ...                   )
84     ...          )
85     >>> c.add_doc(Document(
86     ...                    docid=2,
87     ...                    tags=set(['important']),
88     ...                    properties=[
89     ...                                ('author', 'Joe'),
90     ...                                ('subject', 'your performance at work')
91     ...                    ],
92     ...                    reference=None,
93     ...                   )
94     ...          )
95     >>> c.add_doc(Document(
96     ...                    docid=3,
97     ...                    tags=set(['urgent']),
98     ...                    properties=[
99     ...                                ('author', 'Scott'),
100     ...                                ('subject', 'car turning in front of you')
101     ...                    ],
102     ...                    reference=None,
103     ...                   )
104     ...          )
105     >>> c.query('author:Scott and important')
106     {1}
107     >>> c.query('*')
108     {1, 2, 3}
109     >>> c.query('*:*')
110     {1, 2, 3}
111     >>> c.query('*:Scott')
112     {1, 3}
113     """
114
115     def __init__(self) -> None:
116         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
117         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
118         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
119         self.documents_by_docid: Dict[str, Document] = {}
120
121     def add_doc(self, doc: Document) -> None:
122         """Add a new Document to the Corpus.  Each Document must have a
123         distinct docid that will serve as its primary identifier.  If
124         the same Document is added multiple times, only the most
125         recent addition is indexed.  If two distinct documents with
126         the same docid are added, the latter klobbers the former in the
127         indexes.
128
129         Each Document may have an optional set of tags which can be
130         used later in expressions to the query method.
131
132         Each Document may have an optional list of key->value tuples
133         which can be used later in expressions to the query method.
134
135         Document includes a user-defined "reference" field which is
136         never interpreted by this module.  This is meant to allow easy
137         mapping between Documents in this corpus and external objects
138         they may represent.
139         """
140
141         if doc.docid in self.documents_by_docid:
142             # Handle collisions; assume that we are re-indexing the
143             # same document so remove it from the indexes before
144             # adding it back again.
145             colliding_doc = self.documents_by_docid[doc.docid]
146             assert colliding_doc.docid == doc.docid
147             del self.documents_by_docid[doc.docid]
148             for tag in colliding_doc.tags:
149                 self.docids_by_tag[tag].remove(doc.docid)
150             for key, value in colliding_doc.properties:
151                 self.docids_by_property[(key, value)].remove(doc.docid)
152                 self.docids_with_property[key].remove(doc.docid)
153
154         # Index the new Document
155         assert doc.docid not in self.documents_by_docid
156         self.documents_by_docid[doc.docid] = doc
157         for tag in doc.tags:
158             self.docids_by_tag[tag].add(doc.docid)
159         for key, value in doc.properties:
160             self.docids_by_property[(key, value)].add(doc.docid)
161             self.docids_with_property[key].add(doc.docid)
162
163     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
164         """Return the set of docids that have a particular tag."""
165         return self.docids_by_tag[tag]
166
167     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
168         """Return the set of docids with a tag that contains a str"""
169
170         ret = set()
171         for search_tag in self.docids_by_tag:
172             if tag in search_tag:
173                 for docid in self.docids_by_tag[search_tag]:
174                     ret.add(docid)
175         return ret
176
177     def get_docids_with_property(self, key: str) -> Set[str]:
178         """Return the set of docids that have a particular property no matter
179         what that property's value.
180
181         """
182         return self.docids_with_property[key]
183
184     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
185         """Return the set of docids that have a particular property with a
186         particular value..
187
188         """
189         return self.docids_by_property[(key, value)]
190
191     def invert_docid_set(self, original: Set[str]) -> Set[str]:
192         """Invert a set of docids."""
193
194         return {docid for docid in self.documents_by_docid if docid not in original}
195
196     def get_doc(self, docid: str) -> Optional[Document]:
197         """Given a docid, retrieve the previously added Document."""
198
199         return self.documents_by_docid.get(docid, None)
200
201     def query(self, query: str) -> Optional[Set[str]]:
202         """Query the corpus for documents that match a logical expression.
203         Returns a (potentially empty) set of docids for the matching
204         (previously added) documents or None on error.
205
206         e.g.
207
208         tag1 and tag2 and not tag3
209
210         (tag1 or tag2) and (tag3 or tag4)
211
212         (tag1 and key2:value2) or (tag2 and key1:value1)
213
214         key:*
215
216         tag1 and key:*
217         """
218
219         try:
220             root = self._parse_query(query)
221         except ParseError as e:
222             print(e.message, file=sys.stderr)
223             return None
224         return root.eval()
225
226     def _parse_query(self, query: str):
227         """Internal parse helper; prefer to use query instead."""
228
229         parens = set(["(", ")"])
230         and_or = set(["and", "or"])
231
232         def operator_precedence(token: str) -> Optional[int]:
233             table = {
234                 "(": 4,  # higher
235                 ")": 4,
236                 "not": 3,
237                 "and": 2,
238                 "or": 1,  # lower
239             }
240             return table.get(token, None)
241
242         def is_operator(token: str) -> bool:
243             return operator_precedence(token) is not None
244
245         def lex(query: str):
246             tokens = query.split()
247             for token in tokens:
248                 # Handle ( and ) operators stuck to the ends of tokens
249                 # that split() doesn't understand.
250                 if len(token) > 1:
251                     first = token[0]
252                     if first in parens:
253                         tail = token[1:]
254                         yield first
255                         token = tail
256                     last = token[-1]
257                     if last in parens:
258                         head = token[0:-1]
259                         yield head
260                         token = last
261                 yield token
262
263         def evaluate(corpus: Corpus, stack: List[str]):
264             node_stack: List[Node] = []
265             for token in stack:
266                 node = None
267                 if not is_operator(token):
268                     node = Node(corpus, Operation.QUERY, [token])
269                 else:
270                     args = []
271                     operation = Operation.from_token(token)
272                     operand_count = operation.num_operands()
273                     if len(node_stack) < operand_count:
274                         raise ParseError(f"Incorrect number of operations for {operation}")
275                     for _ in range(operation.num_operands()):
276                         args.append(node_stack.pop())
277                     node = Node(corpus, operation, args)
278                 node_stack.append(node)
279             return node_stack[0]
280
281         output_stack = []
282         operator_stack = []
283         for token in lex(query):
284             if not is_operator(token):
285                 output_stack.append(token)
286                 continue
287
288             # token is an operator...
289             if token == "(":
290                 operator_stack.append(token)
291             elif token == ")":
292                 ok = False
293                 while len(operator_stack) > 0:
294                     pop_operator = operator_stack.pop()
295                     if pop_operator != "(":
296                         output_stack.append(pop_operator)
297                     else:
298                         ok = True
299                         break
300                 if not ok:
301                     raise ParseError("Unbalanced parenthesis in query expression")
302
303             # and, or, not
304             else:
305                 my_precedence = operator_precedence(token)
306                 if my_precedence is None:
307                     raise ParseError(f"Unknown operator: {token}")
308                 while len(operator_stack) > 0:
309                     peek_operator = operator_stack[-1]
310                     if not is_operator(peek_operator) or peek_operator == "(":
311                         break
312                     peek_precedence = operator_precedence(peek_operator)
313                     if peek_precedence is None:
314                         raise ParseError("Internal error")
315                     if (
316                         (peek_precedence < my_precedence)
317                         or (peek_precedence == my_precedence)
318                         and (peek_operator not in and_or)
319                     ):
320                         break
321                     output_stack.append(operator_stack.pop())
322                 operator_stack.append(token)
323         while len(operator_stack) > 0:
324             token = operator_stack.pop()
325             if token in parens:
326                 raise ParseError("Unbalanced parenthesis in query expression")
327             output_stack.append(token)
328         return evaluate(self, output_stack)
329
330
331 class Node(object):
332     """A query AST node."""
333
334     def __init__(
335         self,
336         corpus: Corpus,
337         op: Operation,
338         operands: Sequence[Union[Node, str]],
339     ):
340         self.corpus = corpus
341         self.op = op
342         self.operands = operands
343
344     def eval(self) -> Set[str]:
345         """Evaluate this node."""
346
347         evaled_operands: List[Union[Set[str], str]] = []
348         for operand in self.operands:
349             if isinstance(operand, Node):
350                 evaled_operands.append(operand.eval())
351             elif isinstance(operand, str):
352                 evaled_operands.append(operand)
353             else:
354                 raise ParseError(f"Unexpected operand: {operand}")
355
356         retval = set()
357         if self.op is Operation.QUERY:
358             for tag in evaled_operands:
359                 if isinstance(tag, str):
360                     if ":" in tag:
361                         try:
362                             key, value = tag.split(":")
363                         except ValueError as v:
364                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
365
366                         if key == '*':
367                             r = set()
368                             for kv, s in self.corpus.docids_by_property.items():
369                                 if value in ('*', kv[1]):
370                                     r.update(s)
371                         else:
372                             if value == '*':
373                                 r = self.corpus.get_docids_with_property(key)
374                             else:
375                                 r = self.corpus.get_docids_by_property(key, value)
376                     else:
377                         if tag == '*':
378                             r = set()
379                             for s in self.corpus.docids_by_tag.values():
380                                 r.update(s)
381                         else:
382                             r = self.corpus.get_docids_by_exact_tag(tag)
383                     retval.update(r)
384                 else:
385                     raise ParseError(f"Unexpected query {tag}")
386         elif self.op is Operation.DISJUNCTION:
387             if len(evaled_operands) != 2:
388                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
389             retval.update(evaled_operands[0])
390             retval.update(evaled_operands[1])
391         elif self.op is Operation.CONJUNCTION:
392             if len(evaled_operands) != 2:
393                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
394             retval.update(evaled_operands[0])
395             retval = retval.intersection(evaled_operands[1])
396         elif self.op is Operation.INVERSION:
397             if len(evaled_operands) != 1:
398                 raise ParseError("Operation.INVERSION (not) expects one operand.")
399             _ = evaled_operands[0]
400             if isinstance(_, set):
401                 retval.update(self.corpus.invert_docid_set(_))
402             else:
403                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
404         return retval
405
406
407 if __name__ == '__main__':
408     import doctest
409
410     doctest.testmod()