Chord name parsing.
[python_utils.git] / music / chords.py
1 #!/usr/bin/env python3
2 # type: ignore
3 # pylint: disable=W0201
4 # pylint: disable=R0904
5
6 # © Copyright 2021-2022, Scott Gasch
7
8 """Parse music chords; work in progress..."""
9
10 import functools
11 import itertools
12 import logging
13 import re
14 import sys
15 from typing import Any, Callable, Dict, Iterator, List, Optional
16
17 import antlr4  # type: ignore
18
19 import acl
20 import bootstrap
21 import decorator_utils
22 from music.chordsLexer import chordsLexer  # type: ignore
23 from music.chordsListener import chordsListener  # type: ignore
24 from music.chordsParser import chordsParser  # type: ignore
25
26 logger = logging.getLogger(__name__)
27 notes = ['A', 'A#', 'B', 'C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#']
28
29
30 def generate_scale(starting_note: str) -> Iterator[str]:
31     starting_note = starting_note.upper()
32     start = 0
33     while notes[start] != starting_note:
34         start += 1
35     while True:
36         if start >= len(notes):
37             start = 0
38         yield notes[start]
39         start += 1
40
41
42 def degree_of_note(root_note: str, target_note: str) -> Optional[int]:
43     root_note = root_note.upper()
44     target_note = target_note.upper()
45     for degree, note in enumerate(itertools.islice(generate_scale(root_note), 24)):
46         print(f'"{target_note}", "{note}", {degree}')
47         if note == target_note:
48             return degree - 1
49     return None
50
51
52 def debug_parse(enter_or_exit_f: Callable[[Any, Any], None]):
53     @functools.wraps(enter_or_exit_f)
54     def debug_parse_wrapper(*args, **kwargs):
55         # slf = args[0]
56         ctx = args[1]
57         depth = ctx.depth()
58         logger.debug(
59             '  ' * (depth - 1)
60             + f'Entering {enter_or_exit_f.__name__} ({ctx.invokingState} / {ctx.exception})'
61         )
62         for c in ctx.getChildren():
63             logger.debug('  ' * (depth - 1) + f'{c} {type(c)}')
64         retval = enter_or_exit_f(*args, **kwargs)
65         return retval
66
67     return debug_parse_wrapper
68
69
70 class ParseException(Exception):
71     """An exception thrown during parsing because of unrecognized input."""
72
73     def __init__(self, message: str) -> None:
74         super().__init__()
75         self.message = message
76
77
78 class RaisingErrorListener(antlr4.DiagnosticErrorListener):
79     """An error listener that raises ParseExceptions."""
80
81     def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
82         raise ParseException(msg)
83
84     def reportAmbiguity(self, recognizer, dfa, startIndex, stopIndex, exact, ambigAlts, configs):
85         pass
86
87     def reportAttemptingFullContext(
88         self, recognizer, dfa, startIndex, stopIndex, conflictingAlts, configs
89     ):
90         pass
91
92     def reportContextSensitivity(self, recognizer, dfa, startIndex, stopIndex, prediction, configs):
93         pass
94
95
96 class Chord:
97     def __init__(self, root_note: str, other_notes: Dict[str, int]):
98         self.root_note = root_note.upper()
99         self.other_notes = other_notes
100
101
102 @decorator_utils.decorate_matching_methods_with(
103     debug_parse,
104     acl=acl.StringWildcardBasedACL(
105         allowed_patterns=[
106             'enter*',
107             'exit*',
108         ],
109         denied_patterns=['enterEveryRule', 'exitEveryRule'],
110         order_to_check_allow_deny=acl.Order.DENY_ALLOW,
111         default_answer=False,
112     ),
113 )
114 class ChordParser(chordsListener):
115     """A class to parse dates expressed in human language."""
116
117     MINOR = 0
118     MAJOR = 1
119     SUSPENDED = 2
120     DIMINISHED = 3
121     AUGMENTED = 4
122     POWER = 5
123
124     def __init__(self) -> None:
125         pass
126
127     @staticmethod
128     def interval_name_to_semitone_count(
129         interval_name: str, root: Optional[str]
130     ) -> Optional[List[int]]:
131         interval_name = interval_name.lower()
132         interval_name = re.sub(r'\s+', '', interval_name)
133         interval_name = re.sub(r'th', '', interval_name)
134         interval_name = re.sub(r'add', '', interval_name)
135         interval_name = re.sub(r'perfect', '', interval_name)
136         logger.debug('Canonicalized interval name: %s', interval_name)
137
138         number = None
139         g = re.search(r'[1-9]+', interval_name)
140         if g:
141             number = int(g.group(0))
142         else:
143             return None
144         logger.debug('Number: %d', number)
145
146         minor = 'min' in interval_name or 'b' in interval_name
147         diminished = 'dim' in interval_name
148         augmented = 'aug' in interval_name or '#' in interval_name
149         base_intervals = {
150             2: 2,
151             3: 4,
152             4: 5,
153             5: 7,
154             6: 9,
155             7: 11,
156             9: 14,
157             10: 16,
158             11: 17,
159             13: 19,
160         }
161
162         base_interval = base_intervals.get(number, None)
163         if base_interval is None:
164             return None
165         logger.debug('Starting base_interval is %d', base_interval)
166
167         if diminished:
168             logger.debug('Diminished...')
169             base_interval -= 2
170         elif minor:
171             logger.debug('Minor...')
172             base_interval -= 1
173         elif augmented:
174             logger.debug('Augmented...')
175             base_interval += 1
176         logger.debug('Returning %d semitones.', base_interval)
177         return base_interval
178
179     def parse(self, chord_string: str) -> Optional[Chord]:
180         chord_string = chord_string.strip()
181         chord_string = re.sub(r'\s+', ' ', chord_string)
182         self._reset()
183         listener = RaisingErrorListener()
184         input_stream = antlr4.InputStream(chord_string)
185         lexer = chordsLexer(input_stream)
186         lexer.removeErrorListeners()
187         lexer.addErrorListener(listener)
188         stream = antlr4.CommonTokenStream(lexer)
189         parser = chordsParser(stream)
190         parser.removeErrorListeners()
191         parser.addErrorListener(listener)
192         tree = parser.parse()
193         walker = antlr4.ParseTreeWalker()
194         walker.walk(self, tree)
195         return self.chord
196
197     def _reset(self) -> None:
198         self.chord = None
199         self.rootNote = None
200         self.susNote = None
201         self.bassNote = None
202         self.chordType = ChordParser.MAJOR
203         self.addedNotes = []
204
205     # -- overridden methods invoked by parse walk.  Note: not part of the class'
206     # public API(!!) --
207
208     def visitErrorNode(self, node: antlr4.ErrorNode) -> None:
209         pass
210
211     def visitTerminal(self, node: antlr4.TerminalNode) -> None:
212         pass
213
214     def exitParse(self, ctx: chordsParser.ParseContext) -> None:
215         """Populate self.chord"""
216         print(f'Root note is a {self.rootNote}')
217         scale = list(itertools.islice(generate_scale(self.rootNote), 24))
218
219         chord_types_with_perfect_5th = set(
220             [
221                 ChordParser.MAJOR,
222                 ChordParser.MINOR,
223                 ChordParser.SUSPENDED,
224                 ChordParser.POWER,
225             ]
226         )
227         if self.chordType in chord_types_with_perfect_5th:
228             if self.chordType == ChordParser.MAJOR:
229                 logger.debug('Major chord.')
230                 self.addedNotes.append('maj3')
231             elif self.chordType == ChordParser.MINOR:
232                 logger.debug('Minor chord.')
233                 self.addedNotes.append('min3')
234             elif self.chordType == ChordParser.SUSPENDED:
235                 if self.susNote == 2:
236                     logger.debug('sus2 chord.')
237                     self.addedNotes.append('maj2')
238                 elif self.susNote == 4:
239                     logger.debug('sus4 chord.')
240                     self.addedNotes.append('perfect4')
241             elif self.chordType == ChordParser.POWER:
242                 logger.debug('Power chord.')
243             self.addedNotes.append('perfect5th')
244         elif self.chordType == ChordParser.DIMINISHED:
245             logger.debug('Diminished chord.')
246             self.addedNotes.append('min3')
247             self.addedNotes.append('dim5')
248         elif self.chordType == ChordParser.AUGMENTED:
249             logger.debug('Augmented chord.')
250             self.addedNotes.append('maj3')
251             self.addedNotes.append('aug5')
252
253         other_notes: Dict[str, int] = {}
254         for expression in self.addedNotes:
255             semitone_count = ChordParser.interval_name_to_semitone_count(expression, self.rootNote)
256             if semitone_count in (14, 17, 19):
257                 other_notes['min7'] = 10
258             other_notes[expression] = semitone_count
259
260         for expression, semitone_count in other_notes.items():
261             note_name = scale[semitone_count]
262             print(f'Contains: {expression} ({semitone_count} semitones) => {note_name}')
263         if self.bassNote:
264             degree = degree_of_note(self.rootNote, self.bassNote)
265             print(f'Add a {self.bassNote} ({degree}) in the bass')
266             other_notes[self.bassNote] = degree
267         self.chord = Chord(self.rootNote, other_notes)
268
269     def exitRootNote(self, ctx: chordsParser.RootNoteContext):
270         self.rootNote = ctx.NOTE().__str__().upper()
271         logger.debug('Root note is "%s"', self.rootNote)
272
273     def exitOverBassNoteExpr(self, ctx: chordsParser.OverBassNoteExprContext):
274         self.bassNote = ctx.NOTE().__str__().upper()
275         logger.debug('Bass note is "%s"', self.bassNote)
276
277     def exitPowerChordExpr(self, ctx: chordsParser.PowerChordExprContext):
278         self.chordType = ChordParser.POWER
279         logger.debug('Power chord')
280
281     def exitMajExpr(self, ctx: chordsParser.MajExprContext):
282         self.chordType = ChordParser.MAJOR
283         logger.debug('Major')
284
285     def exitMinExpr(self, ctx: chordsParser.MinExprContext):
286         self.chordType = ChordParser.MINOR
287         logger.debug('Minor')
288
289     def exitSusExpr(self, ctx: chordsParser.SusExprContext):
290         self.chordType = ChordParser.SUSPENDED
291         logger.debug('Suspended')
292         if '2' in ctx.getText():
293             self.susNote = 2
294         elif '4' in ctx.getText():
295             self.susNote = 4
296
297     def exitDiminishedExpr(self, ctx: chordsParser.DiminishedExprContext):
298         self.chordType = ChordParser.DIMINISHED
299         logger.debug('Diminished')
300
301     def exitAugmentedExpr(self, ctx: chordsParser.AugmentedExprContext):
302         self.chordType = ChordParser.AUGMENTED
303         logger.debug('Augmented')
304
305     def exitAddNotesExpr(self, ctx: chordsParser.AddNotesExprContext):
306         if ctx.SIX():
307             self.addedNotes.append('maj6')
308         if ctx.SEVEN():
309             self.addedNotes.append('min7')
310         if ctx.MAJ_SEVEN():
311             self.addedNotes.append('maj7')
312         if ctx.ADD_NINE():
313             self.addedNotes.append('maj9')
314
315     def exitExtensionExpr(self, ctx: chordsParser.ExtensionExprContext):
316         self.addedNotes.append(ctx.getText())
317
318
319 @bootstrap.initialize
320 def main() -> None:
321     parser = ChordParser()
322     for line in sys.stdin:
323         line = line.strip()
324         line = re.sub(r"#.*$", "", line)
325         if re.match(r"^ *$", line) is not None:
326             continue
327         try:
328             chord = parser.parse(line)
329             print(chord)
330         except Exception as e:
331             logger.exception(e)
332             print("Unrecognized.")
333     sys.exit(0)
334
335
336 if __name__ == "__main__":
337     main()