Do the wildcard search thing better.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 # © Copyright 2021-2022, Scott Gasch
4
5 """This is a module concerned with the creation of and searching of a
6 corpus of documents.  The corpus is held in memory for fast
7 searching.
8
9 """
10
11 from __future__ import annotations
12 import enum
13 import sys
14 from collections import defaultdict
15 from dataclasses import dataclass, field
16 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
17
18
19 class ParseError(Exception):
20     """An error encountered while parsing a logical search expression."""
21
22     def __init__(self, message: str):
23         super().__init__()
24         self.message = message
25
26
27 @dataclass
28 class Document:
29     """A class representing a searchable document."""
30
31     # A unique identifier for each document.
32     docid: str = ''
33
34     # A set of tag strings for this document.  May be empty.
35     tags: Set[str] = field(default_factory=set)
36
37     # A list of key->value strings for this document.  May be empty.
38     properties: List[Tuple[str, str]] = field(default_factory=list)
39
40     # An optional reference to something else; interpreted only by
41     # caller code, ignored here.
42     reference: Optional[Any] = None
43
44
45 class Operation(enum.Enum):
46     """A logical search query operation."""
47
48     QUERY = 1
49     CONJUNCTION = 2
50     DISJUNCTION = 3
51     INVERSION = 4
52
53     @staticmethod
54     def from_token(token: str):
55         table = {
56             "not": Operation.INVERSION,
57             "and": Operation.CONJUNCTION,
58             "or": Operation.DISJUNCTION,
59         }
60         return table.get(token, None)
61
62     def num_operands(self) -> Optional[int]:
63         table = {
64             Operation.INVERSION: 1,
65             Operation.CONJUNCTION: 2,
66             Operation.DISJUNCTION: 2,
67         }
68         return table.get(self, None)
69
70
71 class Corpus(object):
72     """A collection of searchable documents.
73
74     >>> c = Corpus()
75     >>> c.add_doc(Document(
76     ...                    docid=1,
77     ...                    tags=set(['urgent', 'important']),
78     ...                    properties=[
79     ...                                ('author', 'Scott'),
80     ...                                ('subject', 'your anniversary')
81     ...                    ],
82     ...                    reference=None,
83     ...                   )
84     ...          )
85     >>> c.add_doc(Document(
86     ...                    docid=2,
87     ...                    tags=set(['important']),
88     ...                    properties=[
89     ...                                ('author', 'Joe'),
90     ...                                ('subject', 'your performance at work')
91     ...                    ],
92     ...                    reference=None,
93     ...                   )
94     ...          )
95     >>> c.add_doc(Document(
96     ...                    docid=3,
97     ...                    tags=set(['urgent']),
98     ...                    properties=[
99     ...                                ('author', 'Scott'),
100     ...                                ('subject', 'car turning in front of you')
101     ...                    ],
102     ...                    reference=None,
103     ...                   )
104     ...          )
105     >>> c.query('author:Scott and important')
106     {1}
107     >>> c.query('*')
108     {1, 2, 3}
109     >>> c.query('*:*')
110     {1, 2, 3}
111     """
112
113     def __init__(self) -> None:
114         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
115         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
116         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
117         self.documents_by_docid: Dict[str, Document] = {}
118
119     def add_doc(self, doc: Document) -> None:
120         """Add a new Document to the Corpus.  Each Document must have a
121         distinct docid that will serve as its primary identifier.  If
122         the same Document is added multiple times, only the most
123         recent addition is indexed.  If two distinct documents with
124         the same docid are added, the latter klobbers the former in the
125         indexes.
126
127         Each Document may have an optional set of tags which can be
128         used later in expressions to the query method.
129
130         Each Document may have an optional list of key->value tuples
131         which can be used later in expressions to the query method.
132
133         Document includes a user-defined "reference" field which is
134         never interpreted by this module.  This is meant to allow easy
135         mapping between Documents in this corpus and external objects
136         they may represent.
137         """
138
139         if doc.docid in self.documents_by_docid:
140             # Handle collisions; assume that we are re-indexing the
141             # same document so remove it from the indexes before
142             # adding it back again.
143             colliding_doc = self.documents_by_docid[doc.docid]
144             assert colliding_doc.docid == doc.docid
145             del self.documents_by_docid[doc.docid]
146             for tag in colliding_doc.tags:
147                 self.docids_by_tag[tag].remove(doc.docid)
148             for key, value in colliding_doc.properties:
149                 self.docids_by_property[(key, value)].remove(doc.docid)
150                 self.docids_with_property[key].remove(doc.docid)
151
152         # Index the new Document
153         assert doc.docid not in self.documents_by_docid
154         self.documents_by_docid[doc.docid] = doc
155         for tag in doc.tags:
156             self.docids_by_tag[tag].add(doc.docid)
157         for key, value in doc.properties:
158             self.docids_by_property[(key, value)].add(doc.docid)
159             self.docids_with_property[key].add(doc.docid)
160
161     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
162         """Return the set of docids that have a particular tag."""
163         return self.docids_by_tag[tag]
164
165     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
166         """Return the set of docids with a tag that contains a str"""
167
168         ret = set()
169         for search_tag in self.docids_by_tag:
170             if tag in search_tag:
171                 for docid in self.docids_by_tag[search_tag]:
172                     ret.add(docid)
173         return ret
174
175     def get_docids_with_property(self, key: str) -> Set[str]:
176         """Return the set of docids that have a particular property no matter
177         what that property's value.
178
179         """
180         return self.docids_with_property[key]
181
182     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
183         """Return the set of docids that have a particular property with a
184         particular value..
185
186         """
187         return self.docids_by_property[(key, value)]
188
189     def invert_docid_set(self, original: Set[str]) -> Set[str]:
190         """Invert a set of docids."""
191
192         return {docid for docid in self.documents_by_docid if docid not in original}
193
194     def get_doc(self, docid: str) -> Optional[Document]:
195         """Given a docid, retrieve the previously added Document."""
196
197         return self.documents_by_docid.get(docid, None)
198
199     def query(self, query: str) -> Optional[Set[str]]:
200         """Query the corpus for documents that match a logical expression.
201         Returns a (potentially empty) set of docids for the matching
202         (previously added) documents or None on error.
203
204         e.g.
205
206         tag1 and tag2 and not tag3
207
208         (tag1 or tag2) and (tag3 or tag4)
209
210         (tag1 and key2:value2) or (tag2 and key1:value1)
211
212         key:*
213
214         tag1 and key:*
215         """
216
217         try:
218             root = self._parse_query(query)
219         except ParseError as e:
220             print(e.message, file=sys.stderr)
221             return None
222         return root.eval()
223
224     def _parse_query(self, query: str):
225         """Internal parse helper; prefer to use query instead."""
226
227         parens = set(["(", ")"])
228         and_or = set(["and", "or"])
229
230         def operator_precedence(token: str) -> Optional[int]:
231             table = {
232                 "(": 4,  # higher
233                 ")": 4,
234                 "not": 3,
235                 "and": 2,
236                 "or": 1,  # lower
237             }
238             return table.get(token, None)
239
240         def is_operator(token: str) -> bool:
241             return operator_precedence(token) is not None
242
243         def lex(query: str):
244             tokens = query.split()
245             for token in tokens:
246                 # Handle ( and ) operators stuck to the ends of tokens
247                 # that split() doesn't understand.
248                 if len(token) > 1:
249                     first = token[0]
250                     if first in parens:
251                         tail = token[1:]
252                         yield first
253                         token = tail
254                     last = token[-1]
255                     if last in parens:
256                         head = token[0:-1]
257                         yield head
258                         token = last
259                 yield token
260
261         def evaluate(corpus: Corpus, stack: List[str]):
262             node_stack: List[Node] = []
263             for token in stack:
264                 node = None
265                 if not is_operator(token):
266                     node = Node(corpus, Operation.QUERY, [token])
267                 else:
268                     args = []
269                     operation = Operation.from_token(token)
270                     operand_count = operation.num_operands()
271                     if len(node_stack) < operand_count:
272                         raise ParseError(f"Incorrect number of operations for {operation}")
273                     for _ in range(operation.num_operands()):
274                         args.append(node_stack.pop())
275                     node = Node(corpus, operation, args)
276                 node_stack.append(node)
277             return node_stack[0]
278
279         output_stack = []
280         operator_stack = []
281         for token in lex(query):
282             if not is_operator(token):
283                 output_stack.append(token)
284                 continue
285
286             # token is an operator...
287             if token == "(":
288                 operator_stack.append(token)
289             elif token == ")":
290                 ok = False
291                 while len(operator_stack) > 0:
292                     pop_operator = operator_stack.pop()
293                     if pop_operator != "(":
294                         output_stack.append(pop_operator)
295                     else:
296                         ok = True
297                         break
298                 if not ok:
299                     raise ParseError("Unbalanced parenthesis in query expression")
300
301             # and, or, not
302             else:
303                 my_precedence = operator_precedence(token)
304                 if my_precedence is None:
305                     raise ParseError(f"Unknown operator: {token}")
306                 while len(operator_stack) > 0:
307                     peek_operator = operator_stack[-1]
308                     if not is_operator(peek_operator) or peek_operator == "(":
309                         break
310                     peek_precedence = operator_precedence(peek_operator)
311                     if peek_precedence is None:
312                         raise ParseError("Internal error")
313                     if (
314                         (peek_precedence < my_precedence)
315                         or (peek_precedence == my_precedence)
316                         and (peek_operator not in and_or)
317                     ):
318                         break
319                     output_stack.append(operator_stack.pop())
320                 operator_stack.append(token)
321         while len(operator_stack) > 0:
322             token = operator_stack.pop()
323             if token in parens:
324                 raise ParseError("Unbalanced parenthesis in query expression")
325             output_stack.append(token)
326         return evaluate(self, output_stack)
327
328
329 class Node(object):
330     """A query AST node."""
331
332     def __init__(
333         self,
334         corpus: Corpus,
335         op: Operation,
336         operands: Sequence[Union[Node, str]],
337     ):
338         self.corpus = corpus
339         self.op = op
340         self.operands = operands
341
342     def eval(self) -> Set[str]:
343         """Evaluate this node."""
344
345         evaled_operands: List[Union[Set[str], str]] = []
346         for operand in self.operands:
347             if isinstance(operand, Node):
348                 evaled_operands.append(operand.eval())
349             elif isinstance(operand, str):
350                 evaled_operands.append(operand)
351             else:
352                 raise ParseError(f"Unexpected operand: {operand}")
353
354         retval = set()
355         if self.op is Operation.QUERY:
356             for tag in evaled_operands:
357                 if isinstance(tag, str):
358                     if ":" in tag:
359                         try:
360                             key, value = tag.split(":")
361                         except ValueError as v:
362                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
363                         if key == '*':
364                             r = set()
365                             for s in self.corpus.docids_by_tag.values():
366                                 r.update(s)
367                         else:
368                             if value == '*':
369                                 r = self.corpus.get_docids_with_property(key)
370                             else:
371                                 r = self.corpus.get_docids_by_property(key, value)
372                     else:
373                         if tag == '*':
374                             r = set()
375                             for s in self.corpus.docids_by_tag.values():
376                                 r.update(s)
377                         else:
378                             r = self.corpus.get_docids_by_exact_tag(tag)
379                     retval.update(r)
380                 else:
381                     raise ParseError(f"Unexpected query {tag}")
382         elif self.op is Operation.DISJUNCTION:
383             if len(evaled_operands) != 2:
384                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
385             retval.update(evaled_operands[0])
386             retval.update(evaled_operands[1])
387         elif self.op is Operation.CONJUNCTION:
388             if len(evaled_operands) != 2:
389                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
390             retval.update(evaled_operands[0])
391             retval = retval.intersection(evaled_operands[1])
392         elif self.op is Operation.INVERSION:
393             if len(evaled_operands) != 1:
394                 raise ParseError("Operation.INVERSION (not) expects one operand.")
395             _ = evaled_operands[0]
396             if isinstance(_, set):
397                 retval.update(self.corpus.invert_docid_set(_))
398             else:
399                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
400         return retval
401
402
403 if __name__ == '__main__':
404     import doctest
405
406     doctest.testmod()