# Captured page header (not part of the program): "Spaces: Runtime error".
import gradio as gr | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from corpy.morphodita import Tokenizer | |
import transformers | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Checkpoint name used when loading the tokenizer further down in this file.
model_checkpoint = 'ufal/robeczech-base'
# Prefer the GPU when torch reports one as available; otherwise use the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Restrict transformers' own logging to error-level messages.
transformers.logging.set_verbosity(transformers.logging.ERROR)
def classify_sentence(sent: str):
    """Classify a single sentence with the module-level ``model``.

    The sentence is encoded with the module-level ``tokenizer`` (with
    truncation enabled) and run through ``model`` without gradient tracking.

    Args:
        sent: The sentence text to classify.

    Returns:
        A CPU tensor produced by ``argmax`` over the class dimension of the
        model's logits, i.e. the index of the highest-scoring class for the
        encoded sentence.
    """
    toksentence = tokenizer(sent, truncation=True, return_tensors="pt")
    # Send the inputs to wherever the model's weights actually live. Using
    # the global ``device`` here fails with a device-mismatch error whenever
    # the model has not been moved to that same device.
    toksentence = toksentence.to(model.device)
    model.eval()
    with torch.no_grad():
        output = model(**toksentence)
    # Softmax is monotonic, so argmax over the raw logits picks the same
    # class. ``.cpu()`` lets callers convert the result to NumPy even when
    # the model runs on a GPU.
    return output.logits.argmax(dim=1).cpu()
def classify_text(text: str):
    """Split *text* into sentences and classify each of them.

    The text is segmented with a MorphoDiTa ``Tokenizer("czech")``; the
    tokens of every sentence are joined with single spaces and the resulting
    string is passed to ``classify_sentence``.

    Returns:
        A NumPy array built from the per-sentence ``classify_sentence``
        results, in sentence order.
    """
    splitter = Tokenizer("czech")
    token_lists = list(splitter.tokenize(text, sents=True))
    sentences = np.array([' '.join(tokens) for tokens in token_lists])
    return np.array([classify_sentence(sentence) for sentence in sentences])
def classify_text_wrapper(text: str):
    """Return the share of non-biased and biased sentences in *text*.

    Args:
        text: The text to classify; it is handed to ``classify_text``.

    Returns:
        A dict with the keys ``'Non-biased'`` and ``'Biased'`` holding the
        fraction of sentence labels equal to 0 and to 1 respectively. Both
        fractions are ``0.0`` when ``classify_text`` yields no results.
    """
    result = classify_text(text)
    n = len(result)
    if n == 0:
        # No sentences were classified (e.g. empty input); avoid dividing
        # by zero and report empty shares instead.
        return {'Non-biased': 0.0, 'Biased': 0.0}
    non_biased = np.count_nonzero(result == 0)
    biased = np.count_nonzero(result == 1)
    return {'Non-biased': non_biased / n, 'Biased': biased / n}
def interpret_bias(text: str):
    """Pair every sentence of *text* with a bias score.

    The text is classified with ``classify_text`` and segmented again with a
    MorphoDiTa ``Tokenizer("czech")`` to recover the sentence strings.

    Returns:
        A list of ``(sentence, score)`` tuples, where ``score`` is ``-1``
        for a sentence labelled 0 (non biased), ``1`` for a sentence
        labelled 1 (biased) and ``0`` for any other label.
    """
    result = classify_text(text)
    splitter = Tokenizer("czech")
    sentences = np.array(
        [' '.join(tokens) for tokens in splitter.tokenize(text, sents=True)]
    )
    interpretation = []
    for idx, sentence in enumerate(sentences):
        label = result[idx]
        if label == 1:
            # biased
            score = 1
        elif label == 0:
            # non biased
            score = -1
        else:
            score = 0
        interpretation.append((sentence, score))
    return interpretation
# The tokenizer comes from ``model_checkpoint``; the classifier weights are
# loaded from a separate checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSequenceClassification.from_pretrained("sagittariusA/media_bias_classifier_cs")
# Place the weights on the device chosen at the top of the file; inputs are
# moved to a device before inference, so the model has to live there too.
model.to(device)
model.eval()

# NOTE(review): ``gr.outputs`` / ``gr.inputs`` are legacy Gradio namespaces;
# confirm they exist in the Gradio version this app is deployed with.
label = gr.outputs.Label(num_top_classes=2)
# NOTE(review): ``inputs`` is not passed to the Interface below, which uses
# the "textbox" shortcut instead.
inputs = gr.inputs.Textbox(placeholder=None, default="", label=None)
app = gr.Interface(
    fn=classify_text_wrapper,
    title='Bias classifier',
    theme='default',
    inputs="textbox",
    layout='unaligned',
    outputs=label,
    capture_session=True,
    interpretation=interpret_bias,
)
app.launch(inbrowser=True)