File size: 4,075 Bytes
4ecf8d9
530d98b
 
fdb5dd9
530d98b
 
 
 
 
a5b6307
530d98b
fdb5dd9
530d98b
 
 
 
 
 
 
 
 
fdb5dd9
 
530d98b
 
fdb5dd9
530d98b
 
fdb5dd9
 
 
 
 
 
9045e66
fdb5dd9
530d98b
 
fdb5dd9
530d98b
 
 
 
 
 
 
 
 
 
849c4b7
530d98b
 
fdb5dd9
 
530d98b
fdb5dd9
530d98b
 
fdb5dd9
530d98b
a5b6307
fdb5dd9
530d98b
 
 
fdb5dd9
19d1bcf
530d98b
fdb5dd9
530d98b
fdb5dd9
19d1bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77f4643
19d1bcf
 
 
 
 
 
 
 
1912310
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
from typing import Dict, List
import torch
torch.backends.cudnn.enabled = False
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqClassifier
import re

# Set model parameters
# NOTE(review): these presumably have to match the values the checkpoint was
# trained with -- confirm against the training configuration.
max_len = 128  # passed as ``to_len`` when encoding input text
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True

# Inference device and on-disk locations of the checkpoint and cached artifacts.
device = "cpu"
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Load vocabulary and intent index mapping
# NOTE(security): unpickling can execute arbitrary code; vocab.pkl must come
# from a trusted source (it is read from the local cache directory here).
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
# Decode explicitly as UTF-8 (the JSON interchange encoding) instead of relying
# on the platform's default encoding, which differs between systems.
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))
# Reverse mapping (class index -> intent name) used to turn predictions into labels.
__idx2label = {idx: intent for intent, idx in intent2idx.items()}

def idx2label(idx: int):
    """Return the intent name stored under class index ``idx``.

    The lookup goes through the module-level reverse mapping; an index that
    is not present in that mapping raises ``KeyError``.
    """
    label = __idx2label[idx]
    return label

# Set embedding layer size
# NOTE(review): (5621, 300) is presumably (vocabulary size, embedding dim) of
# the trained model -- confirm it matches the checkpoint.
embeddings_size = (5621, 300)
# Uninitialized placeholder created directly on the target device.
# ``Tensor.to`` returns a new tensor rather than moving in place, so calling it
# without using the result has no effect.
# NOTE(review): the values are presumably overwritten when the state dict is
# loaded below -- confirm SeqClassifier stores the embeddings in its state.
embeddings = torch.empty(embeddings_size, device=device)

# Load the best model
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Define the path to the checkpoint file
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the model's weights, remapping stored tensors onto the same device the
# model lives on (single source of truth: ``device``).
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
best_model.eval()

# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Tokenize a raw string and encode it as a tensor of vocabulary ids.

    Tokenization is regex based: each run of word characters becomes one
    token and every other non-whitespace character becomes its own token.

    Args:
        texts: The raw input string (a single utterance, despite the name).

    Returns:
        A tensor built from the first (and only) row returned by
        ``vocab.encode_batch`` for this token list.
    """
    tokens = re.findall(r"\w+|[^\w\s]", texts)
    # encode_batch works on a batch, so wrap the single token list and unwrap
    # the single result.
    # NOTE(review): ``to_len=max_len`` presumably pads/truncates to a fixed
    # length -- confirm against utils.Vocab.encode_batch.
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)

# Classification function
def classifier(text):
    """Predict the intent label for a single piece of text.

    Args:
        text: Raw input string to classify.

    Returns:
        The intent name that ``idx2label`` maps the predicted class index to.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: disable autograd so no graph is recorded for the
    # forward pass.
    with torch.no_grad():
        output = best_model(encoded_text)
    # No ``dim`` is given, so argmax is taken over the flattened model output.
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)

import gradio as gr
from gradio.components import Textbox
import random

def random_sample():
    """Return one entry of the module-level ``examples`` list, chosen at random.

    Raises:
        IndexError: If ``examples`` is empty.
    """
    return random.choice(examples)

# Sample utterances; random_sample() returns one of these so the UI can
# pre-fill the input box with it.
examples=[
        "what are some fun things i can partake in in atlanta",
        "how do i make pumpkin pie",
        "what's the currency conversion between rubles and pounds",
        "please set an alarm for mid day",
        "how many hours will it take to get to my destination",
        "so i made a fraudulent transaction",
        "tell lydia and laura where i am located",
        "i want you to talk more quickly",
        "what's the deal with my health care",
        "What's the exchange rate for rubles to pounds",
        "How long will it take to reach my destination",
        "I suspect a fraudulent transaction on my account",
        "Inform Lydia and Laura of my current location",
        "I'd like you to speak faster",
        "Can you provide information about my health care",
        "Give me the details on my health insurance",
        "What's the local time now",
        "Find a recipe for chocolate chip cookies",
        "Check my credit card balance",
        "Translate 'Hello' to French",
        "Recommend a good restaurant nearby",
]
# Short title string for the demo.
title="Text Intent Classification"
# Markdown text rendered at the top of the page via gr.Markdown.
description="""
# Text Intent Classification
This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.
"""

# Build the demo page: a Markdown header, an input/output textbox pair and two
# buttons. The page title comes from the module-level ``title`` constant rather
# than a hard-coded "Question Answering" string, which did not describe this
# intent-classification demo.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)

    with gr.Row():
        # NOTE(review): the "Context paragraph" label looks inherited from a
        # question-answering template -- confirm the intended wording.
        C_input = gr.Textbox(lines=3, label="Context paragraph", placeholder="Please enter text")
        A_output = gr.Textbox(lines=3, label="Category")
    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")

    # "Random" fills the input box with a sample utterance; "classifier" runs
    # the model on the input box and writes the predicted intent to the output.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    classifier_button.click(classifier, inputs=C_input, outputs=A_output)

# Start the web server; share=True additionally requests a public share link.
demo.launch(share=True)