Teach chord parser about Minor7
[python_utils.git] / music / chords.py
1 #!/usr/bin/env python3
2 # type: ignore
3 # pylint: disable=W0201
4 # pylint: disable=R0904
5
6 # © Copyright 2021-2022, Scott Gasch
7
8 """Parse music chords; work in progress..."""
9
10 import functools
11 import itertools
12 import logging
13 import re
14 import sys
15 from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple
16
17 import antlr4  # type: ignore
18
19 import acl
20 import bootstrap
21 import decorator_utils
22 import list_utils
23 from music.chordsLexer import chordsLexer  # type: ignore
24 from music.chordsListener import chordsListener  # type: ignore
25 from music.chordsParser import chordsParser  # type: ignore
26
27 logger = logging.getLogger(__name__)
28
29
30 def debug_parse(enter_or_exit_f: Callable[[Any, Any], None]):
31     @functools.wraps(enter_or_exit_f)
32     def debug_parse_wrapper(*args, **kwargs):
33         # slf = args[0]
34         ctx = args[1]
35         depth = ctx.depth()
36         logger.debug(
37             '  ' * (depth - 1)
38             + f'Entering {enter_or_exit_f.__name__} ({ctx.invokingState} / {ctx.exception})'
39         )
40         for c in ctx.getChildren():
41             logger.debug('  ' * (depth - 1) + f'{c} {type(c)}')
42         retval = enter_or_exit_f(*args, **kwargs)
43         return retval
44
45     return debug_parse_wrapper
46
47
48 class ParseException(Exception):
49     """An exception thrown during parsing because of unrecognized input."""
50
51     def __init__(self, message: str) -> None:
52         super().__init__()
53         self.message = message
54
55
56 class RaisingErrorListener(antlr4.DiagnosticErrorListener):
57     """An error listener that raises ParseExceptions."""
58
59     def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
60         raise ParseException(msg)
61
62     def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
63         pass
64
65     def reportAttemptingFullContext(
66         self, recognizer, dfa, startIndex, stopIndex, conflictingAlts, configs
67     ):
68         pass
69
70     def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
71         pass
72
73
74 class Chord:
75     NOTES = ['A', 'A#', 'B', 'C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#']
76
77     def __init__(self, root_note: str, other_notes: Set[int]) -> None:
78         self.root_note = root_note.upper()
79         self.other_notes: Dict[int, Optional[str]] = {}
80         for semitone_count in other_notes:
81             name = Chord.describe_interval(semitone_count)
82             if name:
83                 self.other_notes[semitone_count] = name
84             else:
85                 self.other_notes[semitone_count] = None
86         self.scale = list(itertools.islice(self.generate_scale(), 24))
87
88     def generate_scale(self) -> Iterator[str]:
89         starting_note = self.root_note
90         start = 0
91         while Chord.NOTES[start] != starting_note:
92             start += 1
93         while True:
94             if start >= len(Chord.NOTES):
95                 start = 0
96             yield Chord.NOTES[start]
97             start += 1
98
99     @staticmethod
100     def describe_interval(semitone_count: int) -> Optional[str]:
101         names: Dict[int, str] = {
102             0: 'perfect unison',
103             1: 'minor 2nd',
104             2: 'major 2nd',
105             3: 'minor 3rd',
106             4: 'major 3rd',
107             5: 'perfect 4th',
108             6: 'diminished 5th',
109             7: 'perfect 5th',
110             8: 'minor 6th',
111             9: 'major 6th',
112             10: 'minor 7th',
113             11: 'major 7th',
114             12: 'perfect octave',
115             13: 'minor 9th',
116             14: 'major 9th',
117             15: 'minor 10th',
118             16: 'major 10th',
119             17: 'perfect 11th',
120             18: 'diminished 12th',
121             19: 'perfect 12th',
122             20: 'minor 13th',
123             21: 'major 13th',
124             22: 'minor 14th',
125             23: 'major 14th',
126             24: 'double octave',
127             25: 'augmented 15th',
128         }
129         return names.get(semitone_count, None)
130
131     def degree_of_note(self, target_note: str) -> Optional[Tuple[Optional[str], int]]:
132         target_note = target_note.upper()
133         for degree, note in enumerate(itertools.islice(self.generate_scale(), 24)):
134             if note == target_note:
135                 return (Chord.describe_interval(degree), degree)
136         return None
137
138     def note_at_degree(self, semitone_count: int) -> Optional[str]:
139         if semitone_count < len(self.scale):
140             return self.scale[semitone_count]
141         return None
142
143     def add_bass(self, bass_note: str) -> None:
144         name, semitone_count = self.degree_of_note(bass_note)
145         if name:
146             self.other_notes[semitone_count] = f'{name} (bass)'
147         else:
148             self.other_notes[semitone_count] = 'bass note'
149
150     def describe_chord(self) -> str:
151         names: Dict[Tuple[int], str] = {
152             (2, 7): 'sus2',
153             (4, 7): 'major',
154             (5, 7): 'sus4',
155             (3, 7): 'minor',
156             (4, 7, 10): 'dom7',
157             (4, 7, 11): 'maj7',
158             (4, 8): 'aug',
159             (3, 5): 'dim',
160             (4, 7, 10, 17): 'add11',
161             (4, 7, 10, 14): 'add9',
162             (5): 'power chord',
163         }
164         intervals = list(self.other_notes.keys())
165         intervals.sort()
166         intervals_set = set(intervals)
167         for interval in reversed(list(list_utils.powerset(intervals))):
168             name = names.get(tuple(interval), None)
169             if name:
170                 for x in interval:
171                     intervals_set.remove(x)
172                 break
173
174         if not name:
175             name = 'An unknown chord'
176
177         for semitone_count, note_name in self.other_notes.items():
178             if 'bass' in note_name:
179                 note = self.note_at_degree(semitone_count)
180                 name += f' with a bass {note}'
181                 if semitone_count in intervals_set:
182                     intervals_set.remove(semitone_count)
183
184         for note in intervals_set:
185             note = self.note_at_degree(note)
186             name += f' add {note}'
187
188         return f'{self.root_note} ({name})'
189
190     def __repr__(self):
191         name = self.describe_chord()
192         ret = f'{name}\n'
193         ret += f'root={self.root_note}\n'
194         for semitone_count, interval_name in self.other_notes.items():
195             note = self.note_at_degree(semitone_count)
196             ret += f'+ {interval_name} ({semitone_count}) => {note}\n'
197         return ret
198
199
200 @decorator_utils.decorate_matching_methods_with(
201     debug_parse,
202     acl=acl.StringWildcardBasedACL(
203         allowed_patterns=[
204             'enter*',
205             'exit*',
206         ],
207         denied_patterns=['enterEveryRule', 'exitEveryRule'],
208         order_to_check_allow_deny=acl.Order.DENY_ALLOW,
209         default_answer=False,
210     ),
211 )
212 class ChordParser(chordsListener):
213     """A class to parse dates expressed in human language."""
214
215     MINOR = 0
216     MAJOR = 1
217     SUSPENDED = 2
218     DIMINISHED = 3
219     AUGMENTED = 4
220     POWER = 5
221
222     def __init__(self) -> None:
223         pass
224
225     @staticmethod
226     def interval_name_to_semitone_count(
227         interval_name: str, root: Optional[str]
228     ) -> Optional[List[int]]:
229         interval_name = interval_name.lower()
230         interval_name = re.sub(r'\s+', '', interval_name)
231         interval_name = re.sub(r'th', '', interval_name)
232         interval_name = re.sub(r'add', '', interval_name)
233         interval_name = re.sub(r'perfect', '', interval_name)
234         logger.debug('Canonicalized interval name: %s', interval_name)
235
236         number = None
237         g = re.search(r'[1-9]+', interval_name)
238         if g:
239             number = int(g.group(0))
240         else:
241             return None
242         logger.debug('Number: %d', number)
243
244         minor = 'min' in interval_name or 'b' in interval_name
245         diminished = 'dim' in interval_name
246         augmented = 'aug' in interval_name or '#' in interval_name
247         base_intervals = {
248             2: 2,
249             3: 4,
250             4: 5,
251             5: 7,
252             6: 9,
253             7: 11,
254             9: 14,
255             10: 16,
256             11: 17,
257             13: 19,
258         }
259
260         base_interval = base_intervals.get(number, None)
261         if base_interval is None:
262             return None
263         logger.debug('Starting base_interval is %d', base_interval)
264
265         if diminished:
266             logger.debug('Diminished...')
267             base_interval -= 2
268         elif minor:
269             logger.debug('Minor...')
270             base_interval -= 1
271         elif augmented:
272             logger.debug('Augmented...')
273             base_interval += 1
274         logger.debug('Returning %d semitones.', base_interval)
275         return base_interval
276
277     def parse(self, chord_string: str) -> Optional[Chord]:
278         chord_string = chord_string.strip()
279         chord_string = re.sub(r'\s+', ' ', chord_string)
280         self._reset()
281         listener = RaisingErrorListener()
282         input_stream = antlr4.InputStream(chord_string)
283         lexer = chordsLexer(input_stream)
284         lexer.removeErrorListeners()
285         lexer.addErrorListener(listener)
286         stream = antlr4.CommonTokenStream(lexer)
287         parser = chordsParser(stream)
288         parser.removeErrorListeners()
289         parser.addErrorListener(listener)
290         tree = parser.parse()
291         walker = antlr4.ParseTreeWalker()
292         walker.walk(self, tree)
293         return self.chord
294
295     def _reset(self) -> None:
296         self.chord = None
297         self.rootNote = None
298         self.susNote = None
299         self.bassNote = None
300         self.chordType = ChordParser.MAJOR
301         self.addedNotes = []
302
303     # -- overridden methods invoked by parse walk.  Note: not part of the class'
304     # public API(!!) --
305
306     def visitErrorNode(self, node: antlr4.ErrorNode) -> None:
307         pass
308
309     def visitTerminal(self, node: antlr4.TerminalNode) -> None:
310         pass
311
312     def exitParse(self, ctx: chordsParser.ParseContext) -> None:
313         """Populate self.chord"""
314
315         chord_types_with_perfect_5th = set(
316             [
317                 ChordParser.MAJOR,
318                 ChordParser.MINOR,
319                 ChordParser.SUSPENDED,
320                 ChordParser.POWER,
321             ]
322         )
323         if self.chordType in chord_types_with_perfect_5th:
324             if self.chordType == ChordParser.MAJOR:
325                 logger.debug('Major chord.')
326                 self.addedNotes.append('maj3')
327             elif self.chordType == ChordParser.MINOR:
328                 logger.debug('Minor chord.')
329                 self.addedNotes.append('min3')
330             elif self.chordType == ChordParser.SUSPENDED:
331                 if self.susNote == 2:
332                     logger.debug('sus2 chord.')
333                     self.addedNotes.append('maj2')
334                 elif self.susNote == 4:
335                     logger.debug('sus4 chord.')
336                     self.addedNotes.append('perfect4')
337             elif self.chordType == ChordParser.POWER:
338                 logger.debug('Power chord.')
339             self.addedNotes.append('perfect5th')
340         elif self.chordType == ChordParser.DIMINISHED:
341             logger.debug('Diminished chord.')
342             self.addedNotes.append('min3')
343             self.addedNotes.append('dim5')
344         elif self.chordType == ChordParser.AUGMENTED:
345             logger.debug('Augmented chord.')
346             self.addedNotes.append('maj3')
347             self.addedNotes.append('aug5')
348
349         other_notes: Set[int] = set()
350         for expression in self.addedNotes:
351             semitone_count = ChordParser.interval_name_to_semitone_count(expression, self.rootNote)
352             other_notes.add(semitone_count)
353             if semitone_count in (14, 17, 19):
354                 other_notes.add(10)
355         self.chord = Chord(self.rootNote, other_notes)
356         if self.bassNote:
357             self.chord.add_bass(self.bassNote)
358
359     def exitRootNote(self, ctx: chordsParser.RootNoteContext):
360         self.rootNote = ctx.NOTE().__str__().upper()
361         logger.debug('Root note is "%s"', self.rootNote)
362
363     def exitOverBassNoteExpr(self, ctx: chordsParser.OverBassNoteExprContext):
364         self.bassNote = ctx.NOTE().__str__().upper()
365         logger.debug('Bass note is "%s"', self.bassNote)
366
367     def exitPowerChordExpr(self, ctx: chordsParser.PowerChordExprContext):
368         self.chordType = ChordParser.POWER
369         logger.debug('Power chord')
370
371     def exitMajExpr(self, ctx: chordsParser.MajExprContext):
372         self.chordType = ChordParser.MAJOR
373         logger.debug('Major')
374
375     def exitMinExpr(self, ctx: chordsParser.MinExprContext):
376         self.chordType = ChordParser.MINOR
377         logger.debug('Minor')
378
379     def exitSusExpr(self, ctx: chordsParser.SusExprContext):
380         self.chordType = ChordParser.SUSPENDED
381         logger.debug('Suspended')
382         if '2' in ctx.getText():
383             self.susNote = 2
384         elif '4' in ctx.getText():
385             self.susNote = 4
386
387     def exitDiminishedExpr(self, ctx: chordsParser.DiminishedExprContext):
388         self.chordType = ChordParser.DIMINISHED
389         logger.debug('Diminished')
390
391     def exitAugmentedExpr(self, ctx: chordsParser.AugmentedExprContext):
392         self.chordType = ChordParser.AUGMENTED
393         logger.debug('Augmented')
394
395     def exitAddNotesExpr(self, ctx: chordsParser.AddNotesExprContext):
396         if ctx.SIX():
397             self.addedNotes.append('maj6')
398         if ctx.SEVEN():
399             self.addedNotes.append('min7')
400         if ctx.MAJ_SEVEN():
401             self.addedNotes.append('maj7')
402         if ctx.MIN_SEVEN():
403             self.addedNotes.append('min7')
404             self.chordType = ChordParser.MINOR
405         if ctx.NINE():
406             self.addedNotes.append('maj9')
407         if ctx.ELEVEN():
408             self.addedNotes.append('min7')
409             self.addedNotes.append('maj11')
410
411     def exitExtensionExpr(self, ctx: chordsParser.ExtensionExprContext):
412         self.addedNotes.append(ctx.getText())
413
414
415 @bootstrap.initialize
416 def main() -> None:
417     parser = ChordParser()
418     for line in sys.stdin:
419         line = line.strip()
420         line = re.sub(r"#.*$", "", line)
421         if re.match(r"^ *$", line) is not None:
422             continue
423         try:
424             chord = parser.parse(line)
425             print(chord)
426         except Exception as e:
427             logger.exception(e)
428             print("Unrecognized.")
429     sys.exit(0)
430
431
432 if __name__ == "__main__":
433     main()