import pandas as pd
from transformers import pipeline, AutoTokenizer
import pysbd
import torch

# Label mappings: human-readable outcome -> [class index, pipeline label id] for each model head
label_mapping_wikipedia_en = {
    'delete': [0, 'LABEL_0'],
    'keep': [1, 'LABEL_1'],
    'merge': [2, 'LABEL_2'],
    'no consensus': [3, 'LABEL_3'],
    'speedy keep': [4, 'LABEL_4'],
    'speedy delete': [5, 'LABEL_5'],
    'redirect': [6, 'LABEL_6'],
    'withdrawn': [7, 'LABEL_7']
}

label_mapping_es = {
    'Borrar': [0, 'LABEL_0'],
    'Mantener': [1, 'LABEL_1'],
    'Fusionar': [2, 'LABEL_2'],
    'Otros': [3, 'LABEL_3']
}

label_mapping_gr = {
    'Διαγραφή': [0, 'LABEL_0'],
    'Δεν υπάρχει συναίνεση': [1, 'LABEL_1'],
    'Διατήρηση': [2, 'LABEL_2'],
    'συγχώνευση': [3, 'LABEL_3']
}

label_mapping_wikidata_ent = {
    'delete': [0, 'LABEL_0'],
    'no_consensus': [1, 'LABEL_1'],
    'merge': [2, 'LABEL_2'],
    'keep': [3, 'LABEL_3'],
    'comment': [4, 'LABEL_4'],
    'redirect': [5, 'LABEL_5']
}

label_mapping_wikidata_prop = {
    'deleted': [0, 'LABEL_0'],
    'keep': [1, 'LABEL_1'],
    'no_consensus': [2, 'LABEL_2']
}

label_mapping_wikinews = {
    'delete': [0, 'LABEL_0'],
    'no_consensus': [1, 'LABEL_1'],
    'speedy delete': [2, 'LABEL_2'],
    'keep': [3, 'LABEL_3'],
    'redirect': [4, 'LABEL_4'],
    'comment': [5, 'LABEL_5'],
    'merge': [6, 'LABEL_6'],
    'withdrawn': [7, 'LABEL_7']
}

label_mapping_wikiquote = {
    'merge': [0, 'LABEL_0'],
    'keep': [1, 'LABEL_1'],
    'no_consensus': [2, 'LABEL_2'],
    'redirect': [3, 'LABEL_3'],
    'delete': [4, 'LABEL_4']
}

# Best-performing fine-tuned checkpoints: per platform (English) and per language (Wikipedia only)
best_models_tasks = {
    'wikipedia': 'research-dump/roberta-large_deletion_multiclass_complete_final_v2',
    'wikidata_entity': 'research-dump/roberta-large_wikidata_ent_outcome_prediction_v1',
    'wikidata_property': 'research-dump/roberta-large_wikidata_prop_outcome_prediction_v1',
    'wikinews': 'research-dump/all-roberta-large-v1_wikinews_outcome_prediction_v1',
    'wikiquote': 'research-dump/roberta-large_wikiquote_outcome_prediction_v1'
}

best_models_langs = {
    'en': 'research-dump/roberta-large_deletion_multiclass_complete_final_v2',
    'es': 'research-dump/xlm-roberta-large_deletion_multiclass_es',
    'gr': 'research-dump/xlm-roberta-large_deletion_multiclass_gr'
}
#-----------------Outcome Prediction-----------------
def outcome(text, lang='en', platform='wikipedia', date='', years=None):
    # Pick the fine-tuned checkpoint and label mapping for the requested language/platform.
    # Note: 'date' and 'years' are accepted for interface compatibility but are not used here.
    if lang == 'en':
        if platform not in best_models_tasks:
            raise ValueError(f"For lang='en', platform must be one of {list(best_models_tasks.keys())}")
        model_name = best_models_tasks[platform]
        if platform == 'wikipedia':
            label_mapping = label_mapping_wikipedia_en
        elif platform == 'wikidata_entity':
            label_mapping = label_mapping_wikidata_ent
        elif platform == 'wikidata_property':
            label_mapping = label_mapping_wikidata_prop
        elif platform == 'wikinews':
            label_mapping = label_mapping_wikinews
        elif platform == 'wikiquote':
            label_mapping = label_mapping_wikiquote
    elif lang in ['es', 'gr']:
        if platform != 'wikipedia':
            raise ValueError(f"For lang='{lang}', only platform='wikipedia' is supported.")
        model_name = best_models_langs[lang]
        label_mapping = label_mapping_es if lang == 'es' else label_mapping_gr
    else:
        raise ValueError("Invalid lang. Use 'en', 'es', or 'gr'.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, return_all_scores=True, device=device)

    # Truncate the input to the model's 512-token limit before scoring
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)

    results = model(truncated_text)
    res_list = []
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == value[1]:
                res_list.append({'sentence': truncated_text, 'outcome': key, 'score': result['score']})
                break
    return res_list
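# Illustrative usage (not part of the original script; the discussion snippet is invented
# and the scores are placeholders):
#   outcome("Delete. Non-notable band with no coverage in independent sources.",
#           lang='en', platform='wikipedia')
# returns one entry per candidate outcome, e.g.
#   [{'sentence': '...', 'outcome': 'delete', 'score': ...},
#    {'sentence': '...', 'outcome': 'keep', 'score': ...}, ...]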
# Shared helper used by stance and policy prediction below:
# scores a single sentence and returns {label_name: probability} for every label in label_mapping.
def extract_response(text, model_name, label_mapping):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    pipe = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None)
    tokens = tokenizer(text, truncation=True, max_length=512)
    truncated_text = tokenizer.decode(tokens['input_ids'], skip_special_tokens=True)
    results = pipe(truncated_text)
    final_scores = {key: 0.0 for key in label_mapping}
    for result in results[0]:
        for key, value in label_mapping.items():
            if result['label'] == f'LABEL_{value}':
                final_scores[key] = result['score']
                break
    return final_scores
#-----------------Stance Detection-----------------
def get_stance(text):
    label_mapping = {
        'delete': 0,
        'keep': 1,
        'merge': 2,
        'comment': 3
    }
    # Split the discussion into sentences and classify the stance of each one
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'stance': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list
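# Illustrative usage (not part of the original script; the sentences are invented and the
# predicted stances shown depend on the model output):
#   get_stance("Keep, the subject clearly passes WP:GNG. Merge the stub into the parent article.")
# returns one entry per sentence, e.g.
#   [{'sentence': 'Keep, ...', 'stance': 'keep', 'score': ...},
#    {'sentence': 'Merge ...', 'stance': 'merge', 'score': ...}]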
#-----------------Policy Prediction-----------------
def get_policy(text):
    label_mapping = {
        'Wikipedia:Notability': 0,
        'Wikipedia:What Wikipedia is not': 1,
        'Wikipedia:Neutral point of view': 2,
        'Wikipedia:Verifiability': 3,
        'Wikipedia:Wikipedia is not a dictionary': 4,
        'Wikipedia:Wikipedia is not for things made up one day': 5,
        'Wikipedia:Criteria for speedy deletion': 6,
        'Wikipedia:Deletion policy': 7,
        'Wikipedia:No original research': 8,
        'Wikipedia:Biographies of living persons': 9,
        'Wikipedia:Arguments to avoid in deletion discussions': 10,
        'Wikipedia:Conflict of interest': 11,
        'Wikipedia:Articles for deletion': 12
    }
    # Split into sentences and predict the most relevant policy for each one
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    model = 'research-dump/bert-large-uncased_wikistance_policy_v1'
    res_list = []
    for t in text_list:
        res = extract_response(t, model, label_mapping)
        highest_key = max(res, key=res.get)
        highest_score = res[highest_key]
        result = {'sentence': t, 'policy': highest_key, 'score': highest_score}
        res_list.append(result)
    return res_list
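# Illustrative usage (not part of the original script; the sentence is invented and the
# predicted policy depends on the model output):
#   get_policy("There is no significant coverage in reliable sources, so notability is not established.")
# -> [{'sentence': '...', 'policy': 'Wikipedia:Notability', 'score': ...}]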
#-----------------Sentiment Analysis-----------------
def extract_highest_score_label(res):
    # Flatten the pipeline output (a list of lists of {label, score}) and return the top label
    flat_res = [item for sublist in res for item in sublist]
    highest_score_item = max(flat_res, key=lambda x: x['score'])
    highest_score_label = highest_score_item['label']
    highest_score_value = highest_score_item['score']
    return highest_score_label, highest_score_value

def get_sentiment(text):
    # Sentence-level sentiment analysis
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, top_k=None)
    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'sentiment': highest_label, 'score': highest_score}
        res.append(result)
    return res
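# Illustrative usage (not part of the original script; the sentence is invented and the
# label names come from the underlying cardiffnlp model, typically negative/neutral/positive):
#   get_sentiment("This article is pure promotional spam.")
# -> [{'sentence': '...', 'sentiment': 'negative', 'score': ...}]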
#-----------------Toxicity Prediction-----------------
def get_offensive_label(text):
    # Offensive-language detection model
    model_name = "cardiffnlp/twitter-roberta-base-offensive"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = pipeline("text-classification", model=model_name, top_k=None)
    # Sentence-tokenize the text using pysbd
    seg = pysbd.Segmenter(language="en", clean=False)
    text_list = seg.segment(text)
    res = []
    for t in text_list:
        results = model(t)
        highest_label, highest_score = extract_highest_score_label(results)
        result = {'sentence': t, 'offensive_label': highest_label, 'score': highest_score}
        res.append(result)
    return res
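# Illustrative usage (not part of the original script; the sentence is invented and the
# label names are whatever the offensive-language model emits, typically 'offensive' / 'non-offensive'):
#   get_offensive_label("You people are idiots and should be banned.")
# -> [{'sentence': '...', 'offensive_label': 'offensive', 'score': ...}]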
# Anchor function: dispatch the request to the chosen task
def predict_text(text, model_name, lang='en', platform='wikipedia', date='', years=None):
    if model_name == 'outcome':
        return outcome(text, lang=lang, platform=platform, date=date, years=years)
    elif model_name == 'stance':
        return get_stance(text)
    elif model_name == 'policy':
        return get_policy(text)
    elif model_name == 'sentiment':
        return get_sentiment(text)
    elif model_name == 'offensive':
        return get_offensive_label(text)
    else:
        return "Invalid model name"