Ditch named tuples for dataclasses.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 """This is a module concerned with the creation of and searching of a
4 corpus of documents.  The corpus is held in memory for fast
5 searching."""
6
7 from __future__ import annotations
8
9 import enum
10 import sys
11 from collections import defaultdict
12 from dataclasses import dataclass, field
13 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
14
15
16 class ParseError(Exception):
17     """An error encountered while parsing a logical search expression."""
18
19     def __init__(self, message: str):
20         super().__init__()
21         self.message = message
22
23
24 @dataclass
25 class Document:
26     """A class representing a searchable document."""
27
28     # A unique identifier for each document.
29     docid: str = ''
30
31     # A set of tag strings for this document.  May be empty.
32     tags: Set[str] = field(default_factory=set)
33
34     # A list of key->value strings for this document.  May be empty.
35     properties: List[Tuple[str, str]] = field(default_factory=list)
36
37     # An optional reference to something else; interpreted only by
38     # caller code, ignored here.
39     reference: Optional[Any] = None
40
41
42 class Operation(enum.Enum):
43     """A logical search query operation."""
44
45     QUERY = 1
46     CONJUNCTION = 2
47     DISJUNCTION = 3
48     INVERSION = 4
49
50     @staticmethod
51     def from_token(token: str):
52         table = {
53             "not": Operation.INVERSION,
54             "and": Operation.CONJUNCTION,
55             "or": Operation.DISJUNCTION,
56         }
57         return table.get(token, None)
58
59     def num_operands(self) -> Optional[int]:
60         table = {
61             Operation.INVERSION: 1,
62             Operation.CONJUNCTION: 2,
63             Operation.DISJUNCTION: 2,
64         }
65         return table.get(self, None)
66
67
68 class Corpus(object):
69     """A collection of searchable documents.
70
71     >>> c = Corpus()
72     >>> c.add_doc(Document(
73     ...                    docid=1,
74     ...                    tags=set(['urgent', 'important']),
75     ...                    properties=[
76     ...                                ('author', 'Scott'),
77     ...                                ('subject', 'your anniversary')
78     ...                    ],
79     ...                    reference=None,
80     ...                   )
81     ...          )
82     >>> c.add_doc(Document(
83     ...                    docid=2,
84     ...                    tags=set(['important']),
85     ...                    properties=[
86     ...                                ('author', 'Joe'),
87     ...                                ('subject', 'your performance at work')
88     ...                    ],
89     ...                    reference=None,
90     ...                   )
91     ...          )
92     >>> c.add_doc(Document(
93     ...                    docid=3,
94     ...                    tags=set(['urgent']),
95     ...                    properties=[
96     ...                                ('author', 'Scott'),
97     ...                                ('subject', 'car turning in front of you')
98     ...                    ],
99     ...                    reference=None,
100     ...                   )
101     ...          )
102     >>> c.query('author:Scott and important')
103     {1}
104     """
105
106     def __init__(self) -> None:
107         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
108         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
109         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
110         self.documents_by_docid: Dict[str, Document] = {}
111
112     def add_doc(self, doc: Document) -> None:
113         """Add a new Document to the Corpus.  Each Document must have a
114         distinct docid that will serve as its primary identifier.  If
115         the same Document is added multiple times, only the most
116         recent addition is indexed.  If two distinct documents with
117         the same docid are added, the latter klobbers the former in the
118         indexes.
119
120         Each Document may have an optional set of tags which can be
121         used later in expressions to the query method.
122
123         Each Document may have an optional list of key->value tuples
124         which can be used later in expressions to the query method.
125
126         Document includes a user-defined "reference" field which is
127         never interpreted by this module.  This is meant to allow easy
128         mapping between Documents in this corpus and external objects
129         they may represent.
130         """
131
132         if doc.docid in self.documents_by_docid:
133             # Handle collisions; assume that we are re-indexing the
134             # same document so remove it from the indexes before
135             # adding it back again.
136             colliding_doc = self.documents_by_docid[doc.docid]
137             assert colliding_doc.docid == doc.docid
138             del self.documents_by_docid[doc.docid]
139             for tag in colliding_doc.tags:
140                 self.docids_by_tag[tag].remove(doc.docid)
141             for key, value in colliding_doc.properties:
142                 self.docids_by_property[(key, value)].remove(doc.docid)
143                 self.docids_with_property[key].remove(doc.docid)
144
145         # Index the new Document
146         assert doc.docid not in self.documents_by_docid
147         self.documents_by_docid[doc.docid] = doc
148         for tag in doc.tags:
149             self.docids_by_tag[tag].add(doc.docid)
150         for key, value in doc.properties:
151             self.docids_by_property[(key, value)].add(doc.docid)
152             self.docids_with_property[key].add(doc.docid)
153
154     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
155         """Return the set of docids that have a particular tag."""
156
157         return self.docids_by_tag[tag]
158
159     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
160         """Return the set of docids with a tag that contains a str"""
161
162         ret = set()
163         for search_tag in self.docids_by_tag:
164             if tag in search_tag:
165                 for docid in self.docids_by_tag[search_tag]:
166                     ret.add(docid)
167         return ret
168
169     def get_docids_with_property(self, key: str) -> Set[str]:
170         """Return the set of docids that have a particular property no matter
171         what that property's value.
172
173         """
174         return self.docids_with_property[key]
175
176     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
177         """Return the set of docids that have a particular property with a
178         particular value..
179
180         """
181         return self.docids_by_property[(key, value)]
182
183     def invert_docid_set(self, original: Set[str]) -> Set[str]:
184         """Invert a set of docids."""
185
186         return set([docid for docid in self.documents_by_docid.keys() if docid not in original])
187
188     def get_doc(self, docid: str) -> Optional[Document]:
189         """Given a docid, retrieve the previously added Document."""
190
191         return self.documents_by_docid.get(docid, None)
192
193     def query(self, query: str) -> Optional[Set[str]]:
194         """Query the corpus for documents that match a logical expression.
195         Returns a (potentially empty) set of docids for the matching
196         (previously added) documents or None on error.
197
198         e.g.
199
200         tag1 and tag2 and not tag3
201
202         (tag1 or tag2) and (tag3 or tag4)
203
204         (tag1 and key2:value2) or (tag2 and key1:value1)
205
206         key:*
207
208         tag1 and key:*
209         """
210
211         try:
212             root = self._parse_query(query)
213         except ParseError as e:
214             print(e.message, file=sys.stderr)
215             return None
216         return root.eval()
217
218     def _parse_query(self, query: str):
219         """Internal parse helper; prefer to use query instead."""
220
221         parens = set(["(", ")"])
222         and_or = set(["and", "or"])
223
224         def operator_precedence(token: str) -> Optional[int]:
225             table = {
226                 "(": 4,  # higher
227                 ")": 4,
228                 "not": 3,
229                 "and": 2,
230                 "or": 1,  # lower
231             }
232             return table.get(token, None)
233
234         def is_operator(token: str) -> bool:
235             return operator_precedence(token) is not None
236
237         def lex(query: str):
238             tokens = query.split()
239             for token in tokens:
240                 # Handle ( and ) operators stuck to the ends of tokens
241                 # that split() doesn't understand.
242                 if len(token) > 1:
243                     first = token[0]
244                     if first in parens:
245                         tail = token[1:]
246                         yield first
247                         token = tail
248                     last = token[-1]
249                     if last in parens:
250                         head = token[0:-1]
251                         yield head
252                         token = last
253                 yield token
254
255         def evaluate(corpus: Corpus, stack: List[str]):
256             node_stack: List[Node] = []
257             for token in stack:
258                 node = None
259                 if not is_operator(token):
260                     node = Node(corpus, Operation.QUERY, [token])
261                 else:
262                     args = []
263                     operation = Operation.from_token(token)
264                     operand_count = operation.num_operands()
265                     if len(node_stack) < operand_count:
266                         raise ParseError(f"Incorrect number of operations for {operation}")
267                     for _ in range(operation.num_operands()):
268                         args.append(node_stack.pop())
269                     node = Node(corpus, operation, args)
270                 node_stack.append(node)
271             return node_stack[0]
272
273         output_stack = []
274         operator_stack = []
275         for token in lex(query):
276             if not is_operator(token):
277                 output_stack.append(token)
278                 continue
279
280             # token is an operator...
281             if token == "(":
282                 operator_stack.append(token)
283             elif token == ")":
284                 ok = False
285                 while len(operator_stack) > 0:
286                     pop_operator = operator_stack.pop()
287                     if pop_operator != "(":
288                         output_stack.append(pop_operator)
289                     else:
290                         ok = True
291                         break
292                 if not ok:
293                     raise ParseError("Unbalanced parenthesis in query expression")
294
295             # and, or, not
296             else:
297                 my_precedence = operator_precedence(token)
298                 if my_precedence is None:
299                     raise ParseError(f"Unknown operator: {token}")
300                 while len(operator_stack) > 0:
301                     peek_operator = operator_stack[-1]
302                     if not is_operator(peek_operator) or peek_operator == "(":
303                         break
304                     peek_precedence = operator_precedence(peek_operator)
305                     if peek_precedence is None:
306                         raise ParseError("Internal error")
307                     if (
308                         (peek_precedence < my_precedence)
309                         or (peek_precedence == my_precedence)
310                         and (peek_operator not in and_or)
311                     ):
312                         break
313                     output_stack.append(operator_stack.pop())
314                 operator_stack.append(token)
315         while len(operator_stack) > 0:
316             token = operator_stack.pop()
317             if token in parens:
318                 raise ParseError("Unbalanced parenthesis in query expression")
319             output_stack.append(token)
320         return evaluate(self, output_stack)
321
322
323 class Node(object):
324     """A query AST node."""
325
326     def __init__(
327         self,
328         corpus: Corpus,
329         op: Operation,
330         operands: Sequence[Union[Node, str]],
331     ):
332         self.corpus = corpus
333         self.op = op
334         self.operands = operands
335
336     def eval(self) -> Set[str]:
337         """Evaluate this node."""
338
339         evaled_operands: List[Union[Set[str], str]] = []
340         for operand in self.operands:
341             if isinstance(operand, Node):
342                 evaled_operands.append(operand.eval())
343             elif isinstance(operand, str):
344                 evaled_operands.append(operand)
345             else:
346                 raise ParseError(f"Unexpected operand: {operand}")
347
348         retval = set()
349         if self.op is Operation.QUERY:
350             for tag in evaled_operands:
351                 if isinstance(tag, str):
352                     if ":" in tag:
353                         try:
354                             key, value = tag.split(":")
355                         except ValueError as v:
356                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
357                         if value == "*":
358                             r = self.corpus.get_docids_with_property(key)
359                         else:
360                             r = self.corpus.get_docids_by_property(key, value)
361                     else:
362                         r = self.corpus.get_docids_by_exact_tag(tag)
363                     retval.update(r)
364                 else:
365                     raise ParseError(f"Unexpected query {tag}")
366         elif self.op is Operation.DISJUNCTION:
367             if len(evaled_operands) != 2:
368                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
369             retval.update(evaled_operands[0])
370             retval.update(evaled_operands[1])
371         elif self.op is Operation.CONJUNCTION:
372             if len(evaled_operands) != 2:
373                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
374             retval.update(evaled_operands[0])
375             retval = retval.intersection(evaled_operands[1])
376         elif self.op is Operation.INVERSION:
377             if len(evaled_operands) != 1:
378                 raise ParseError("Operation.INVERSION (not) expects one operand.")
379             _ = evaled_operands[0]
380             if isinstance(_, set):
381                 retval.update(self.corpus.invert_docid_set(_))
382             else:
383                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
384         return retval
385
386
387 if __name__ == '__main__':
388     import doctest
389
390     doctest.testmod()