Hanna Hjelmeland
Fix
483b69e
raw
history blame
4 kB
__all__ = ['is_flower', 'learn', 'classify_image', 'categories', 'image', 'label', 'examples', 'intf']
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch
model_name = "NbAiLab/nb-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
first_model_path = "models/first_model"
first_model = AutoModelForSequenceClassification.from_pretrained(first_model_path)
second_model_path = "models/second_model"
second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)
f_30_40_model_path = "models/FEMALE_30_40model"
f_30_40_model = AutoModelForSequenceClassification.from_pretrained(f_30_40_model_path)
f_40_55_model_path = "models/FEMALE_40_55model"
f_40_55_model = AutoModelForSequenceClassification.from_pretrained(f_40_55_model_path)
m_30_40_model_path = "models/MALE_30_40model"
m_30_40_model = AutoModelForSequenceClassification.from_pretrained(m_30_40_model_path)
m_40_55_model_path = "models/MALE_40_55model"
m_40_55_model = AutoModelForSequenceClassification.from_pretrained(m_40_55_model_path)
def classify_text(test_text, selected_model):
categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
if selected_model in ('Model 1', 'Model 2'):
if selected_model == 'Model 1':
model = first_model
elif selected_model == 'Model 2':
model = second_model
else:
raise ValueError("Invalid model selection")
inputs = tokenizer(test_text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
class_labels = model.config.id2label
predicted_label = class_labels[predicted_class]
probabilities = probabilities[0].tolist()
return dict(zip(categories, map(float,probabilities)))
elif selected_model == 'Model 3':
models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
predicted_labels = []
performance_labels = ['Lite god', 'Nokså god', 'God']
inputs = tokenizer(test_text, return_tensors="pt")
for model in models:
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=1)
predicted_class = torch.argmax(probabilities, dim=1).item()
predicted_performance = performance_labels[predicted_class]
predicted_labels.append(predicted_performance)
return dict(zip(categories, map(float,performance_labels)))
# Cell
label = gr.outputs.Label()
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"
examples = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
["Fotballstadion tok fyr i helgen", 'Model 2'],
["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste.", "Model 3"],
["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune.", "Model 3"],
["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende.", "Model 3"]
]
intf = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2', 'Model 3'])], outputs=label, examples=examples, title=app_title)
intf.launch(inline=False)