Make profanity filter catch foo/bar where foo and/or bar are bad
[python_utils.git] / profanity_filter.py
index 31577e0fbf1a4a76486dab066fff764aeb5bbc47..3109f166af211d0160aeca81ddf72e526ceaf2d3 100755 (executable)
@@ -2,6 +2,7 @@
 
 import logging
 import random
+import re
 import string
 import sys
 
@@ -347,6 +348,7 @@ class ProfanityFilter(object):
             'poop chute',
             'poopchute',
             'porn',
+            'pron',
             'pornhub',
             'porno',
             'pornographi',
@@ -469,8 +471,25 @@ class ProfanityFilter(object):
         self.stemmer = PorterStemmer()
 
     def _normalize(self, text: str) -> str:
+        """Normalize text.
+
+        >>> _normalize('Tittie5')
+        'titties'
+
+        >>> _normalize('Suck a Dick!')
+        'suck a dick'
+
+        >>> _normalize('fucking a whore')
+        'fuck a whore'
+
+        """
         result = text.lower()
         result = result.replace("_", " ")
+        result = result.replace('0', 'o')
+        result = result.replace('1', 'l')
+        result = result.replace('4', 'a')
+        result = result.replace('5', 's')
+        result = result.replace('3', 'e')
         for x in string.punctuation:
             result = result.replace(x, "")
         chunks = [
@@ -478,8 +497,26 @@ class ProfanityFilter(object):
         ]
         return ' '.join(chunks)
 
+    def tokenize(self, text: str):
+        for x in nltk.word_tokenize(text):
+            for y in re.split('\W+', x):
+                yield y
+
     def contains_bad_word(self, text: str) -> bool:
-        words = nltk.word_tokenize(text)
+        """Returns True if text contains a bad word (or more than one) 
+        and False if no bad words were detected.
+
+        >>> contains_bad_word('fuck you')
+        True
+
+        >>> contains_bad_word('FucK u')
+        True
+
+        >>> contains_bad_word('FuK U')
+        False
+
+        """
+        words = [word for word in self.tokenize(text)]
         for word in words:
             if self.is_bad_word(word):
                 logger.debug(f'"{word}" is profanity')
@@ -489,14 +526,14 @@ class ProfanityFilter(object):
             for bigram in string_utils.ngrams_presplit(words, 2):
                 bigram = ' '.join(bigram)
                 if self.is_bad_word(bigram):
-                    logger.debug('"{bigram}" is profanity')
+                    logger.debug(f'"{bigram}" is profanity')
                     return True
 
         if len(words) > 2:
             for trigram in string_utils.ngrams_presplit(words, 3):
                 trigram = ' '.join(trigram)
                 if self.is_bad_word(trigram):
-                    logger.debug('"{trigram}" is profanity')
+                    logger.debug(f'"{trigram}" is profanity')
                     return True
         return False
 
@@ -507,7 +544,10 @@ class ProfanityFilter(object):
         )
 
     def obscure_bad_words(self, text: str) -> str:
+        """Obscure bad words that are detected by inserting random punctuation
+        characters.
 
+        """
         def obscure(word: str):
             out = ''
             last = ''
@@ -523,7 +563,7 @@ class ProfanityFilter(object):
                             break
             return out
 
-        words = nltk.word_tokenize(text)
+        words = self.tokenize(text)
         words.append('')
         words.append('')
         words.append('')
@@ -550,6 +590,8 @@ class ProfanityFilter(object):
 
 
 def main() -> None:
+    import doctest
+    doctest.testmod()
     pf = ProfanityFilter()
     phrase = ' '.join(sys.argv[1:])
     print(pf.contains_bad_word(phrase))