|
import functools
import re
import unicodedata

import nltk
from nltk import WordNetLemmatizer
from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import XLMRobertaForSequenceClassification
from transformers import Trainer
import gradio as gr
|
|
|
# Top-level domains accepted by the URL pattern, reproduced from the original
# pattern (including its odd 'Ja' entry). Built as one newline-free
# alternation: the previous triple-quoted literal embedded line breaks into the
# regex, so entries such as 'info', 'name', 'id' or 'za' could never match.
_URL_TLDS = (
    "com|net|org|edu|gov|mil|aero|asia|biz|cat|coop|info|int|jobs|mobi|museum|"
    "name|post|pro|tel|travel|xxx|"
    "ac|ad|ae|af|ag|ai|al|am|an|ao|aq|ar|as|at|au|aw|ax|az|"
    "ba|bb|bd|be|bf|bg|bh|bi|bj|bm|bn|bo|br|bs|bt|bv|bw|by|bz|"
    "ca|cc|cd|cf|cg|ch|ci|ck|cl|cm|cn|co|cr|cs|cu|cv|cx|cy|cz|"
    "dd|de|dj|dk|dm|do|dz|ec|ee|eg|eh|er|es|et|eu|"
    "fi|fj|fk|fm|fo|fr|ga|gb|gd|ge|gf|gg|gh|gi|gl|gm|gn|gp|gq|gr|gs|gt|gu|gw|gy|"
    "hk|hm|hn|hr|ht|hu|id|ie|il|im|in|io|iq|ir|is|it|"
    "je|jm|jo|jp|ke|kg|kh|ki|km|kn|kp|kr|kw|ky|kz|"
    "la|lb|lc|li|lk|lr|ls|lt|lu|lv|ly|"
    "ma|mc|md|me|mg|mh|mk|ml|mm|mn|mo|mp|mq|mr|ms|mt|mu|mv|mw|mx|my|mz|"
    "na|nc|ne|nf|ng|ni|nl|no|np|nr|nu|nz|om|"
    "pa|pe|pf|pg|ph|pk|pl|pm|pn|pr|ps|pt|pw|py|qa|re|ro|rs|ru|rw|"
    "sa|sb|sc|sd|se|sg|sh|si|sj|Ja|sk|sl|sm|sn|so|sr|ss|st|su|sv|sx|sy|sz|"
    "tc|td|tf|tg|th|tj|tk|tl|tm|tn|to|tp|tr|tt|tv|tw|tz|"
    "ua|ug|uk|us|uy|uz|va|vc|ve|vg|vi|vn|vu|wf|ws|ye|yt|yu|za|zm|zw"
)

# Liberal URL matcher. Alternative 1: a scheme ("http:"/"https:") or
# "domain.tld/" followed by a path that may contain balanced parentheses and
# must not end in trailing punctuation. Alternative 2: a bare "domain.tld"
# that is not part of an e-mail address (the @ look-around guards).
_URL_PATTERN_STR = (
    r"(?i)((?:https?:(?:/{1,3}|[a-z0-9%])|[a-z0-9.\-]+[.](?:" + _URL_TLDS + r")/)"
    + r"(?:[^\s()<>{}\[\]]+|\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\))+"
    + r"(?:\([^\s()]*?\([^\s()]+\)[^\s()]*?\)|\([^\s]+?\)|[^\s`!()\[\]{};:'\".,<>?«»“”‘’])"
    + r"|(?:(?<!@)[a-z0-9]+(?:[.\-][a-z0-9]+)*[.](?:" + _URL_TLDS + r")\b/?(?!@)))"
)

# Compiled once at import time instead of on every call.
_URL_PATTERN = re.compile(_URL_PATTERN_STR, re.IGNORECASE)
_HASHTAG_PATTERN = re.compile(r"#\w*")
_MENTION_PATTERN = re.compile(r"@\w*")
_NUMBER_PATTERN = re.compile(r"[-+]?[.\d]*[\d]+[:,.\d]*")
_PUNCT_REPEAT_PATTERN = re.compile(r"([!?.]){2,}")
# A word containing a character repeated 3+ times, e.g. "soooo".
_ELONG_PATTERN = re.compile(r"\b(\S*?)(.)\2{2,}\b")
# Anything that is not a word character, whitespace, or the '<' / '>' that
# delimit the placeholder tokens inserted above.
_NON_WORD_PATTERN = re.compile(r"[^\w<>\s]")


def preprocess_text(text: str) -> str:
    """
    Normalize a social-media style phrase into placeholder-tagged plain text.

    Steps, in order:
      * URLs are replaced by ``<URL>``.
      * ``/`` is padded with spaces.
      * ``@mentions`` become ``<USER>``, numbers ``<NUMBER>``, ``#hashtags``
        ``<HASHTAG>``.
      * Runs of two or more ``!``, ``?`` or ``.`` become the single
        character followed by ``<REPEAT>``.
      * Words with a character repeated three or more times are collapsed
        (``soooo`` -> ``so``) and followed by ``<ELONG>``.
      * Every remaining character that is not a word character, whitespace,
        ``<`` or ``>`` (punctuation, most emoji) is replaced by a space.
      * Leading/trailing whitespace is stripped and accents are removed via
        NFKD decomposition.

    Inner runs of whitespace are not collapsed.

    Args:
        text (str): The input text to be preprocessed.

    Returns:
        str: The preprocessed text.
    """
    text = _URL_PATTERN.sub(" <URL>", text)
    text = text.replace("/", " / ")
    text = _MENTION_PATTERN.sub(" <USER> ", text)
    text = _NUMBER_PATTERN.sub(" <NUMBER> ", text)
    text = _HASHTAG_PATTERN.sub(" <HASHTAG> ", text)
    text = _PUNCT_REPEAT_PATTERN.sub(lambda m: f" {m.group(1)} <REPEAT> ", text)
    text = _ELONG_PATTERN.sub(lambda m: f" {m.group(1)}{m.group(2)} <ELONG> ", text)
    # The substitution never touches whitespace, so one final strip suffices.
    text = _NON_WORD_PATTERN.sub(" ", text).strip()
    # Drop combining marks left by NFKD decomposition ("café" -> "cafe").
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(c for c in decomposed if not unicodedata.combining(c))
|
|
|
def lemmatize_text(text: str) -> str:
    """
    Lemmatize each whitespace-separated token of the input with WordNet.

    A probe call first forces NLTK to load the WordNet corpus. If it is
    missing (``LookupError``), it is downloaded once and lemmatization
    proceeds. A failed download re-raises the error instead of retrying
    forever.

    Args:
        text (str): The input text to be lemmatized; tokens are obtained
            with ``str.split()``.

    Returns:
        str: The lemmatized tokens joined by single spaces.

    Raises:
        LookupError: If WordNet is unavailable and cannot be downloaded, or
            is still not found after downloading.
    """
    lemmatizer = WordNetLemmatizer()
    try:
        lemmatizer.lemmatize(text)
    except LookupError:
        print("Downloading WordNet...")
        # nltk.download reports failure with a falsy return value; surface the
        # original LookupError rather than looping on a download that never
        # succeeds.
        if not nltk.download('wordnet'):
            raise
    return ' '.join(lemmatizer.lemmatize(word) for word in text.split())
|
|
|
@functools.lru_cache(maxsize=8)
def _load_tokenizer_and_model(finetuned_model: str):
    """Load (once per model name) the tokenizer and classifier for *finetuned_model*."""
    # XLM-RoBERTa checkpoints pair with the xlm-roberta-base tokenizer; all
    # others with the cardiffnlp twitter-roberta tokenizer.
    if 'xlm' in finetuned_model.lower():
        tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
        model = XLMRobertaForSequenceClassification.from_pretrained(finetuned_model)
    else:
        tokenizer = AutoTokenizer.from_pretrained('cardiffnlp/twitter-roberta-base-hate')
        model = AutoModelForSequenceClassification.from_pretrained(finetuned_model)
    return tokenizer, model


def predict(phrase: str, finetuned_model: str) -> str:
    """
    Classify a phrase as sexist or not with the selected fine-tuned model.

    The phrase is preprocessed, lemmatized and lower-cased, then tokenized and
    scored. Tokenizer and model are cached per model name, so only the first
    request for a given model pays the download/load cost.

    Args:
        phrase (str): Raw user input.
        finetuned_model (str): Hugging Face model id of the fine-tuned
            classifier.

    Returns:
        str: "Sexist" when the highest-scoring class index is 1, otherwise
        "Not sexist".

    Raises:
        ValueError: If no model was selected.
    """
    if not finetuned_model:
        raise ValueError("Please choose a model from the 'Model' dropdown.")

    phrase = preprocess_text(phrase)
    phrase = lemmatize_text(phrase)
    phrase = phrase.lower()

    tokenizer, model = _load_tokenizer_and_model(finetuned_model)
    # A fresh Trainer per request keeps its prediction state private to this call.
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
    )

    tokens = tokenizer(
        phrase,
        return_tensors="pt"
    )
    phrase_dataset = Dataset.from_dict({
        "input_ids": tokens["input_ids"],
        "attention_mask": tokens["attention_mask"],
    })

    pred = trainer.predict(phrase_dataset)
    # Single-row batch, so the flattened argmax is the predicted class index.
    return "Sexist" if pred.predictions.argmax() == 1 else "Not sexist"
|
|
|
# Fine-tuned checkpoints selectable in the UI; names containing "xlm" are
# routed to the XLM-RoBERTa tokenizer/model by predict().
MODEL_CHOICES = [
    "MatteoFasulo/twitter-roberta-base-hate_69",
    "MatteoFasulo/twitter-roberta-base-hate_1337",
    "MatteoFasulo/twitter-roberta-base-hate_42",
    "MatteoFasulo/xlm-roberta-base_69",
    "MatteoFasulo/xlm-roberta-base_1337",
    "MatteoFasulo/xlm-roberta-base_42",
]

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(
            label="Phrase",
            placeholder="Enter a phrase to check if it is sexist or not.",
            info="Enter a phrase to check if it is sexist or not.",
        ),
        gr.Dropdown(
            MODEL_CHOICES,
            # Preselect a model so predict() never receives None.
            value=MODEL_CHOICES[0],
            label="Model",
            info="Choose the model to use for prediction. XLM-RoBERTa models are suitable for multilingual text.",
        ),
    ],
    outputs=gr.Text(
        label="Prediction",
        info="The prediction of the model on the input phrase.",
    ),
    title="Sexism Detection",
    description="A small demo to check if a phrase is sexist or not using a fine-tuned RoBERTa model on hate speech detection.",
    theme="huggingface",
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()