import json
import pickle
from pathlib import Path
from typing import Dict

import gradio as gr
import torch

from model import SeqClassifier
from utils import Vocab

# Inference runs on CPU here, so cuDNN is not needed; disabling it avoids
# cuDNN-related issues on machines where CUDA happens to be available
torch.backends.cudnn.enabled = False

# Model hyperparameters (these should match the values used when the checkpoint was trained)
max_len = 128
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True

device = "cpu"
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")
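# Expected on-disk layout, inferred from the paths above:
#   ./cache/intent/vocab.pkl           - pickled utils.Vocab
#   ./cache/intent/intent2idx.json     - intent-name -> class-id mapping
#   ./ckpt/intent/model_checkpoint.pth - trained model weights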

# Load vocabulary and intent index mapping
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
__idx2label = {idx: intent for intent, idx in intent2idx.items()}
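# intent2idx.json is assumed to look like {"alarm_set": 0, "share_location": 1, ...}
# (keys are illustrative); __idx2label inverts it so predicted class indices
# can be mapped back to intent names.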

def idx2label(idx: int) -> str:
    return __idx2label[idx]

# Placeholder embedding matrix; the actual weights are presumably restored by
# load_state_dict below, so only the shape (vocab size x embedding dim) matters
embeddings_size = (6491, 300)
embeddings = torch.empty(embeddings_size).to(device)

# Load the best model
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx)
).to(device)

# Define the path to the checkpoint file
ckpt_path = ckpt_dir / "model_checkpoint.pth"

# Load the model's weights
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])
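# This assumes the checkpoint was saved as a dict, e.g.
#   torch.save({"model_state_dict": model.state_dict(), ...}, ckpt_path)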

# Switch to evaluation mode (disables dropout so inference is deterministic)
best_model.eval()

# Convert raw text into a padded tensor of vocabulary indices
def collate_fn(text: str) -> torch.Tensor:
    # Tokenize on whitespace, encode via the vocabulary, and pad/truncate to max_len
    tokens = text.split()
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
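# A quick sanity check under the assumed Vocab interface:
#   collate_fn("set an alarm")  # -> tensor of shape (max_len,), padded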

# Classification function
def classify(text: str) -> str:
    # collate_fn returns a 1-D tensor; SeqClassifier is assumed to accept unbatched input
    encoded_text = collate_fn(text).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        output = best_model(encoded_text)
    predicted_class = torch.argmax(output).item()
    return "Category: " + idx2label(predicted_class)

# Create a Gradio interface
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs="label",
    live=False,
    examples=[
        ["please set an alarm for mid day"],
        ["tell lydia and laura where i am located"],
        ["what's the deal with my health care"]
    ],
    title="Text Intent Classification Demo",
    description="This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.",
    theme="gradio/seafoam"
)

# Queue incoming requests and launch with a temporary public share link
demo.queue()
demo.launch(share=True)