More cleanup.
[python_utils.git] / logical_search.py
1 #!/usr/bin/env python3
2
3 """This is a module concerned with the creation of and searching of a
4 corpus of documents.  The corpus is held in memory for fast
5 searching."""
6
7 from __future__ import annotations
8 import enum
9 import sys
10 from collections import defaultdict
11 from dataclasses import dataclass, field
12 from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
13
14
15 class ParseError(Exception):
16     """An error encountered while parsing a logical search expression."""
17
18     def __init__(self, message: str):
19         super().__init__()
20         self.message = message
21
22
23 @dataclass
24 class Document:
25     """A class representing a searchable document."""
26
27     # A unique identifier for each document.
28     docid: str = ''
29
30     # A set of tag strings for this document.  May be empty.
31     tags: Set[str] = field(default_factory=set)
32
33     # A list of key->value strings for this document.  May be empty.
34     properties: List[Tuple[str, str]] = field(default_factory=list)
35
36     # An optional reference to something else; interpreted only by
37     # caller code, ignored here.
38     reference: Optional[Any] = None
39
40
41 class Operation(enum.Enum):
42     """A logical search query operation."""
43
44     QUERY = 1
45     CONJUNCTION = 2
46     DISJUNCTION = 3
47     INVERSION = 4
48
49     @staticmethod
50     def from_token(token: str):
51         table = {
52             "not": Operation.INVERSION,
53             "and": Operation.CONJUNCTION,
54             "or": Operation.DISJUNCTION,
55         }
56         return table.get(token, None)
57
58     def num_operands(self) -> Optional[int]:
59         table = {
60             Operation.INVERSION: 1,
61             Operation.CONJUNCTION: 2,
62             Operation.DISJUNCTION: 2,
63         }
64         return table.get(self, None)
65
66
67 class Corpus(object):
68     """A collection of searchable documents.
69
70     >>> c = Corpus()
71     >>> c.add_doc(Document(
72     ...                    docid=1,
73     ...                    tags=set(['urgent', 'important']),
74     ...                    properties=[
75     ...                                ('author', 'Scott'),
76     ...                                ('subject', 'your anniversary')
77     ...                    ],
78     ...                    reference=None,
79     ...                   )
80     ...          )
81     >>> c.add_doc(Document(
82     ...                    docid=2,
83     ...                    tags=set(['important']),
84     ...                    properties=[
85     ...                                ('author', 'Joe'),
86     ...                                ('subject', 'your performance at work')
87     ...                    ],
88     ...                    reference=None,
89     ...                   )
90     ...          )
91     >>> c.add_doc(Document(
92     ...                    docid=3,
93     ...                    tags=set(['urgent']),
94     ...                    properties=[
95     ...                                ('author', 'Scott'),
96     ...                                ('subject', 'car turning in front of you')
97     ...                    ],
98     ...                    reference=None,
99     ...                   )
100     ...          )
101     >>> c.query('author:Scott and important')
102     {1}
103     """
104
105     def __init__(self) -> None:
106         self.docids_by_tag: Dict[str, Set[str]] = defaultdict(set)
107         self.docids_by_property: Dict[Tuple[str, str], Set[str]] = defaultdict(set)
108         self.docids_with_property: Dict[str, Set[str]] = defaultdict(set)
109         self.documents_by_docid: Dict[str, Document] = {}
110
111     def add_doc(self, doc: Document) -> None:
112         """Add a new Document to the Corpus.  Each Document must have a
113         distinct docid that will serve as its primary identifier.  If
114         the same Document is added multiple times, only the most
115         recent addition is indexed.  If two distinct documents with
116         the same docid are added, the latter klobbers the former in the
117         indexes.
118
119         Each Document may have an optional set of tags which can be
120         used later in expressions to the query method.
121
122         Each Document may have an optional list of key->value tuples
123         which can be used later in expressions to the query method.
124
125         Document includes a user-defined "reference" field which is
126         never interpreted by this module.  This is meant to allow easy
127         mapping between Documents in this corpus and external objects
128         they may represent.
129         """
130
131         if doc.docid in self.documents_by_docid:
132             # Handle collisions; assume that we are re-indexing the
133             # same document so remove it from the indexes before
134             # adding it back again.
135             colliding_doc = self.documents_by_docid[doc.docid]
136             assert colliding_doc.docid == doc.docid
137             del self.documents_by_docid[doc.docid]
138             for tag in colliding_doc.tags:
139                 self.docids_by_tag[tag].remove(doc.docid)
140             for key, value in colliding_doc.properties:
141                 self.docids_by_property[(key, value)].remove(doc.docid)
142                 self.docids_with_property[key].remove(doc.docid)
143
144         # Index the new Document
145         assert doc.docid not in self.documents_by_docid
146         self.documents_by_docid[doc.docid] = doc
147         for tag in doc.tags:
148             self.docids_by_tag[tag].add(doc.docid)
149         for key, value in doc.properties:
150             self.docids_by_property[(key, value)].add(doc.docid)
151             self.docids_with_property[key].add(doc.docid)
152
153     def get_docids_by_exact_tag(self, tag: str) -> Set[str]:
154         """Return the set of docids that have a particular tag."""
155
156         return self.docids_by_tag[tag]
157
158     def get_docids_by_searching_tags(self, tag: str) -> Set[str]:
159         """Return the set of docids with a tag that contains a str"""
160
161         ret = set()
162         for search_tag in self.docids_by_tag:
163             if tag in search_tag:
164                 for docid in self.docids_by_tag[search_tag]:
165                     ret.add(docid)
166         return ret
167
168     def get_docids_with_property(self, key: str) -> Set[str]:
169         """Return the set of docids that have a particular property no matter
170         what that property's value.
171
172         """
173         return self.docids_with_property[key]
174
175     def get_docids_by_property(self, key: str, value: str) -> Set[str]:
176         """Return the set of docids that have a particular property with a
177         particular value..
178
179         """
180         return self.docids_by_property[(key, value)]
181
182     def invert_docid_set(self, original: Set[str]) -> Set[str]:
183         """Invert a set of docids."""
184
185         return {docid for docid in self.documents_by_docid if docid not in original}
186
187     def get_doc(self, docid: str) -> Optional[Document]:
188         """Given a docid, retrieve the previously added Document."""
189
190         return self.documents_by_docid.get(docid, None)
191
192     def query(self, query: str) -> Optional[Set[str]]:
193         """Query the corpus for documents that match a logical expression.
194         Returns a (potentially empty) set of docids for the matching
195         (previously added) documents or None on error.
196
197         e.g.
198
199         tag1 and tag2 and not tag3
200
201         (tag1 or tag2) and (tag3 or tag4)
202
203         (tag1 and key2:value2) or (tag2 and key1:value1)
204
205         key:*
206
207         tag1 and key:*
208         """
209
210         try:
211             root = self._parse_query(query)
212         except ParseError as e:
213             print(e.message, file=sys.stderr)
214             return None
215         return root.eval()
216
217     def _parse_query(self, query: str):
218         """Internal parse helper; prefer to use query instead."""
219
220         parens = set(["(", ")"])
221         and_or = set(["and", "or"])
222
223         def operator_precedence(token: str) -> Optional[int]:
224             table = {
225                 "(": 4,  # higher
226                 ")": 4,
227                 "not": 3,
228                 "and": 2,
229                 "or": 1,  # lower
230             }
231             return table.get(token, None)
232
233         def is_operator(token: str) -> bool:
234             return operator_precedence(token) is not None
235
236         def lex(query: str):
237             tokens = query.split()
238             for token in tokens:
239                 # Handle ( and ) operators stuck to the ends of tokens
240                 # that split() doesn't understand.
241                 if len(token) > 1:
242                     first = token[0]
243                     if first in parens:
244                         tail = token[1:]
245                         yield first
246                         token = tail
247                     last = token[-1]
248                     if last in parens:
249                         head = token[0:-1]
250                         yield head
251                         token = last
252                 yield token
253
254         def evaluate(corpus: Corpus, stack: List[str]):
255             node_stack: List[Node] = []
256             for token in stack:
257                 node = None
258                 if not is_operator(token):
259                     node = Node(corpus, Operation.QUERY, [token])
260                 else:
261                     args = []
262                     operation = Operation.from_token(token)
263                     operand_count = operation.num_operands()
264                     if len(node_stack) < operand_count:
265                         raise ParseError(f"Incorrect number of operations for {operation}")
266                     for _ in range(operation.num_operands()):
267                         args.append(node_stack.pop())
268                     node = Node(corpus, operation, args)
269                 node_stack.append(node)
270             return node_stack[0]
271
272         output_stack = []
273         operator_stack = []
274         for token in lex(query):
275             if not is_operator(token):
276                 output_stack.append(token)
277                 continue
278
279             # token is an operator...
280             if token == "(":
281                 operator_stack.append(token)
282             elif token == ")":
283                 ok = False
284                 while len(operator_stack) > 0:
285                     pop_operator = operator_stack.pop()
286                     if pop_operator != "(":
287                         output_stack.append(pop_operator)
288                     else:
289                         ok = True
290                         break
291                 if not ok:
292                     raise ParseError("Unbalanced parenthesis in query expression")
293
294             # and, or, not
295             else:
296                 my_precedence = operator_precedence(token)
297                 if my_precedence is None:
298                     raise ParseError(f"Unknown operator: {token}")
299                 while len(operator_stack) > 0:
300                     peek_operator = operator_stack[-1]
301                     if not is_operator(peek_operator) or peek_operator == "(":
302                         break
303                     peek_precedence = operator_precedence(peek_operator)
304                     if peek_precedence is None:
305                         raise ParseError("Internal error")
306                     if (
307                         (peek_precedence < my_precedence)
308                         or (peek_precedence == my_precedence)
309                         and (peek_operator not in and_or)
310                     ):
311                         break
312                     output_stack.append(operator_stack.pop())
313                 operator_stack.append(token)
314         while len(operator_stack) > 0:
315             token = operator_stack.pop()
316             if token in parens:
317                 raise ParseError("Unbalanced parenthesis in query expression")
318             output_stack.append(token)
319         return evaluate(self, output_stack)
320
321
322 class Node(object):
323     """A query AST node."""
324
325     def __init__(
326         self,
327         corpus: Corpus,
328         op: Operation,
329         operands: Sequence[Union[Node, str]],
330     ):
331         self.corpus = corpus
332         self.op = op
333         self.operands = operands
334
335     def eval(self) -> Set[str]:
336         """Evaluate this node."""
337
338         evaled_operands: List[Union[Set[str], str]] = []
339         for operand in self.operands:
340             if isinstance(operand, Node):
341                 evaled_operands.append(operand.eval())
342             elif isinstance(operand, str):
343                 evaled_operands.append(operand)
344             else:
345                 raise ParseError(f"Unexpected operand: {operand}")
346
347         retval = set()
348         if self.op is Operation.QUERY:
349             for tag in evaled_operands:
350                 if isinstance(tag, str):
351                     if ":" in tag:
352                         try:
353                             key, value = tag.split(":")
354                         except ValueError as v:
355                             raise ParseError(f'Invalid key:value syntax at "{tag}"') from v
356                         if value == "*":
357                             r = self.corpus.get_docids_with_property(key)
358                         else:
359                             r = self.corpus.get_docids_by_property(key, value)
360                     else:
361                         r = self.corpus.get_docids_by_exact_tag(tag)
362                     retval.update(r)
363                 else:
364                     raise ParseError(f"Unexpected query {tag}")
365         elif self.op is Operation.DISJUNCTION:
366             if len(evaled_operands) != 2:
367                 raise ParseError("Operation.DISJUNCTION (or) expects two operands.")
368             retval.update(evaled_operands[0])
369             retval.update(evaled_operands[1])
370         elif self.op is Operation.CONJUNCTION:
371             if len(evaled_operands) != 2:
372                 raise ParseError("Operation.CONJUNCTION (and) expects two operands.")
373             retval.update(evaled_operands[0])
374             retval = retval.intersection(evaled_operands[1])
375         elif self.op is Operation.INVERSION:
376             if len(evaled_operands) != 1:
377                 raise ParseError("Operation.INVERSION (not) expects one operand.")
378             _ = evaled_operands[0]
379             if isinstance(_, set):
380                 retval.update(self.corpus.invert_docid_set(_))
381             else:
382                 raise ParseError(f"Unexpected negation operand {_} ({type(_)})")
383         return retval
384
385
386 if __name__ == '__main__':
387     import doctest
388
389     doctest.testmod()