File size: 4,002 Bytes
28981b1
2f01d97
e494bde
ad3a545
 
e494bde
2f01d97
ad3a545
 
2f01d97
480db19
 
2f01d97
480db19
 
2f01d97
8219eaa
 
ad3a545
8219eaa
 
ad3a545
8219eaa
 
28981b1
8219eaa
 
ad3a545
8219eaa
ad3a545
8219eaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483b69e
 
8219eaa
 
 
 
 
 
 
 
 
 
483b69e
8219eaa
 
ad3a545
 
 
 
 
 
688d5f2
 
8219eaa
 
 
 
ad3a545
8219eaa
480db19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
__all__ = ['is_flower', 'learn', 'classify_image', 'categories', 'image', 'label', 'examples', 'intf']

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch


model_name = "NbAiLab/nb-bert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

first_model_path = "models/first_model"
first_model = AutoModelForSequenceClassification.from_pretrained(first_model_path)

second_model_path = "models/second_model"
second_model = AutoModelForSequenceClassification.from_pretrained(second_model_path)

f_30_40_model_path = "models/FEMALE_30_40model"
f_30_40_model = AutoModelForSequenceClassification.from_pretrained(f_30_40_model_path)

f_40_55_model_path = "models/FEMALE_40_55model"
f_40_55_model = AutoModelForSequenceClassification.from_pretrained(f_40_55_model_path)

m_30_40_model_path = "models/MALE_30_40model"
m_30_40_model = AutoModelForSequenceClassification.from_pretrained(m_30_40_model_path)

m_40_55_model_path = "models/MALE_40_55model"
m_40_55_model = AutoModelForSequenceClassification.from_pretrained(m_40_55_model_path)

def classify_text(test_text, selected_model):

    categories = ['Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55']
    
    if selected_model in ('Model 1', 'Model 2'):
        if selected_model == 'Model 1':
            model = first_model
        elif selected_model == 'Model 2':
            model = second_model
        else:
            raise ValueError("Invalid model selection")
        inputs = tokenizer(test_text, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=1)

        predicted_class = torch.argmax(probabilities, dim=1).item()
        class_labels = model.config.id2label
        predicted_label = class_labels[predicted_class]
        probabilities = probabilities[0].tolist()

        return dict(zip(categories, map(float,probabilities)))
    elif selected_model == 'Model 3':
        models = [f_30_40_model, f_40_55_model, m_30_40_model, m_40_55_model]
        predicted_labels = []
        performance_labels = ['Lite god', 'Nokså god', 'God']
        inputs = tokenizer(test_text, return_tensors="pt")

        for model in models:
            with torch.no_grad():
                outputs = model(**inputs)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=1)
            
            predicted_class = torch.argmax(probabilities, dim=1).item()
            predicted_performance = performance_labels[predicted_class]
            predicted_labels.append(predicted_performance)

        return dict(zip(categories, map(float,performance_labels)))

# Cell
label = gr.outputs.Label()
categories = ('Kvinner 30-40', 'Kvinner 40-55', 'Menn 30-40', 'Menn 40-55')
app_title = "Target group classifier"

examples = [["Moren leter etter sønnen i et ihjelbombet leilighetskompleks.", 'Model 1'],
            ["Fotballstadion tok fyr i helgen", 'Model 2'],
            ["De første månedene av krigen gikk så som så. Nå har Putin skiftet strategi.", 'Model 1'],
            ["Title: Disse hadde størst formue i 2022, Text lead: Laksearvingen Gustav Magnar Witzøe økte formuen med nesten 7 milliarder i fjor, og troner nok en gang øverst på listen over Norges rikeste.", "Model 3"],
            ["Title: Dette er de mest populære navnene i 2022, Text lead: Navnetoppen for 2022 er klar. Se hvilke navn som er mest populære i din kommune.", "Model 3"],
            ["Title: 2023 er det varmeste året noen gang registrert, Text lead: En ny rapport viser at 2023 er det varmeste året registrert siden man startet målingene. Klimaforsker kaller tallene urovekkende.", "Model 3"]
]
intf = gr.Interface(fn=classify_text, inputs=["text", gr.Dropdown(['Model 1', 'Model 2', 'Model 3'])], outputs=label, examples=examples, title=app_title)
intf.launch(inline=False)