# Page residue captured with the file (not part of the program):
#   xjlulu's picture
#   "~"
#   77f4643
import gradio as gr
from typing import Dict, List
import torch
torch.backends.cudnn.enabled = False
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqClassifier
import re
# ---------------------------------------------------------------------------
# Inference configuration
# ---------------------------------------------------------------------------

# Device every tensor and the model are placed on.
device = "cpu"

# Directories holding the trained checkpoint and the cached vocabulary /
# intent-index files.
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Length handed to vocab.encode_batch(to_len=...) when encoding an input.
max_len = 128

# Constructor arguments forwarded to SeqClassifier.
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True
# Load the vocabulary and the intent <-> index mappings from the cache dir.
# NOTE: pickle.load can execute arbitrary code; this is acceptable only
# because vocab.pkl is read from the local cache directory, never from user
# input.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
# Decode explicitly as UTF-8 instead of relying on the platform's default
# text encoding, which differs between systems.
intent2idx: Dict[str, int] = json.loads(
    intent_idx_path.read_text(encoding="utf-8")
)

# Reverse lookup (class index -> intent name) used by idx2label().
__idx2label = {idx: intent for intent, idx in intent2idx.items()}
def idx2label(idx: int):
    """Return the intent name registered for class index ``idx``.

    The lookup goes through the module-level reverse mapping built from
    ``intent2idx``; an index missing from that mapping raises ``KeyError``.
    """
    label = __idx2label[idx]
    return label
# Shape of the embedding matrix handed to the classifier.
# NOTE(review): presumably (vocabulary size, embedding dim) and required to
# match the checkpoint loaded below; confirm against the training setup.
embeddings_size = (5621, 300)

# torch.empty leaves the values uninitialized; this tensor only provides the
# shape for model construction (weights are expected to come from the
# checkpoint; confirm SeqClassifier keeps them in its state dict).
# Allocate on the target device directly: Tensor.to() returns a new tensor
# rather than moving in place, so a bare `embeddings.to(device)` has no effect.
embeddings = torch.empty(embeddings_size, device=device)

# Build the classifier from the configuration values defined above.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Path to the checkpoint file.
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the stored weights, remapping tensors onto the configured device so
# the checkpoint location follows the single `device` setting.
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint["model_state_dict"])

# Switch to evaluation mode for inference.
best_model.eval()
# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Tokenize one input string and encode it as a tensor of vocabulary ids.

    Args:
        texts: The raw input text. Despite the plural name this is a single
            string, not a batch.

    Returns:
        A tensor built from the single row returned by
        ``vocab.encode_batch`` when called with ``to_len=max_len``.
    """
    # Split into runs of word characters and individual punctuation marks;
    # whitespace is discarded.
    tokens = re.findall(r"\w+|[^\w\s]", texts)
    # encode_batch takes a batch of token lists, so wrap the one sample and
    # pull its row back out afterwards.
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
# Classification function
def classifier(text):
    """Predict the intent label for ``text``.

    Args:
        text: Raw input string entered by the user.

    Returns:
        The intent name that ``idx2label`` maps the highest-scoring output
        index to.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: skip autograd bookkeeping so no graph is built (and no
    # memory retained for it) on every request.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax without `dim` indexes the flattened output.
    # NOTE(review): assumes the model returns one score per class for this
    # single, unbatched input; confirm against SeqClassifier.forward.
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)
import gradio as gr
from gradio.components import Textbox
import random
# Sample utterances offered by the "Random" button.
examples = [
    "what are some fun things i can partake in in atlanta",
    "how do i make pumpkin pie",
    "what's the currency conversion between rubles and pounds",
    "please set an alarm for mid day",
    "how many hours will it take to get to my destination",
    "so i made a fraudulent transaction",
    "tell lydia and laura where i am located",
    "i want you to talk more quickly",
    "what's the deal with my health care",
    "What's the exchange rate for rubles to pounds",
    "How long will it take to reach my destination",
    "I suspect a fraudulent transaction on my account",
    "Inform Lydia and Laura of my current location",
    "I'd like you to speak faster",
    "Can you provide information about my health care",
    "Give me the details on my health insurance",
    "What's the local time now",
    "Find a recipe for chocolate chip cookies",
    "Check my credit card balance",
    "Translate 'Hello' to French",
    "Recommend a good restaurant nearby",
]


def random_sample():
    """Return one entry of ``examples`` chosen uniformly at random."""
    # random.choice replaces the manual randint-then-index lookup and avoids
    # the easy-to-get-wrong inclusive upper bound.
    return random.choice(examples)
# Page title and the Markdown header rendered at the top of the demo.
title = "Text Intent Classification"
description = """
# Text Intent Classification
This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.
"""

# Use the module-level `title` for the browser tab; the previous hard-coded
# "Question Answering" did not describe this intent-classification demo.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)

    # Input text and predicted category.
    # NOTE(review): the "Context paragraph" label looks like a leftover from
    # a question-answering demo; left unchanged pending confirmation.
    with gr.Row():
        C_input = gr.Textbox(lines=3, label="Context paragraph", placeholder="Please enter text")
        A_output = gr.Textbox(lines=3, label="Category")

    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")

    # "Random" fills the input with a sample utterance; "classifier" runs the
    # model on the current input and shows the predicted intent.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    classifier_button.click(classifier, inputs=C_input, outputs=A_output)

demo.launch(share=True)