Spaces:
Runtime error
Runtime error
import gradio as gr | |
from typing import Dict, List | |
import torch | |
torch.backends.cudnn.enabled = False | |
import json | |
import pickle | |
from pathlib import Path | |
from utils import Vocab | |
from model import SeqClassifier | |
import re | |
# Where the trained checkpoint and the preprocessing cache live on disk.
ckpt_dir: Path = Path("./ckpt/intent/")
cache_dir: Path = Path("./cache/intent/")

# All tensors and the model are placed on this device.
device: str = "cpu"

# Model / encoding settings.
# NOTE(review): hidden_size, num_layers and bidirectional presumably have to
# match the architecture stored in the checkpoint, otherwise load_state_dict
# would reject the mismatched parameter shapes — confirm if retraining.
max_len: int = 128  # passed to Vocab.encode_batch as to_len
hidden_size: int = 256
num_layers: int = 2
dropout: float = 0.1
bidirectional: bool = True
# Load the vocabulary object and the intent-name -> class-index mapping from
# the preprocessing cache directory.
# NOTE(review): pickle.load can execute arbitrary code. This is only acceptable
# because vocab.pkl is presumably a file bundled with the app, never
# user-supplied data.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
# JSON text is UTF-8 by specification; don't depend on the platform's
# locale-dependent default encoding when reading it.
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))
# Reverse lookup table built once at import time: class index -> intent name.
__idx2label = dict(zip(intent2idx.values(), intent2idx.keys()))


def idx2label(idx: int):
    """Translate a predicted class index back into its intent name.

    Raises KeyError when *idx* is not one of the indices in intent2idx.
    """
    return __idx2label[idx]
# Placeholder embedding matrix, used only so SeqClassifier can be constructed.
# torch.empty leaves the values uninitialised; presumably they are overwritten
# when the checkpoint's state dict is loaded — confirm against SeqClassifier.
# NOTE(review): (5621, 300) is hard-coded and presumably equals
# (vocabulary size, embedding dim) at training time. It must be updated if the
# vocabulary or embedding file is regenerated.
embeddings_size = (5621, 300)
# Allocate directly on the target device. Tensor.to() returns a new tensor
# rather than moving the existing one in place, so a bare `embeddings.to(device)`
# statement has no effect.
embeddings = torch.empty(embeddings_size, device=device)
# Build the classifier (presumably with the same architecture used for
# training) and restore the trained weights into it.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

ckpt_path = ckpt_dir / "intent_checkpoint.pth"
# Map stored tensors onto the configured module-level `device` rather than a
# hard-coded one. Changing that single setting then relocates the whole model.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint["model_state_dict"])

# Inference only: eval() puts dropout (and similar layers) into eval behaviour.
best_model.eval()
def collate_fn(texts: str) -> torch.Tensor:
    """Tokenise one input string and encode it as a tensor of vocabulary ids.

    Args:
        texts: the raw user text. It is a single string, despite the plural name.

    Returns:
        The first (and only) row that ``vocab.encode_batch`` produces for this
        text with ``to_len=max_len``, converted with ``torch.tensor``.

    NOTE(review): the result is presumably 1-D (no batch dimension). classifier()
    feeds it to the model as-is — confirm SeqClassifier accepts that shape.
    """
    # Tokens are runs of word characters or single punctuation characters;
    # whitespace is discarded.
    tokens = re.findall(r"\w+|[^\w\s]", texts)
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
def classifier(text):
    """Predict the intent label for *text*; used as the Gradio click handler.

    Args:
        text: raw user input from the text box.

    Returns:
        The intent name (a key of intent2idx) whose model score is highest.
    """
    encoded_text = collate_fn(text).to(device)
    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax over the flattened output. This assumes the output holds one score
    # per class for a single example — TODO confirm SeqClassifier's output shape.
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)
import gradio as gr | |
from gradio.components import Textbox | |
import random | |
def random_sample(pool=None):
    """Return a uniformly random example sentence for the input box.

    Args:
        pool: optional non-empty sequence to draw from. It defaults to the
            module-level ``examples`` list, which is looked up at call time.
            Gradio calls this with no arguments.

    Returns:
        One element of *pool*.

    Raises:
        IndexError: if *pool* is empty.
    """
    if pool is None:
        pool = examples
    return random.choice(pool)
# Sentences that the "Random" button can drop into the input box.
examples = [
    # Informal, lower-case phrasings.
    "what are some fun things i can partake in in atlanta",
    "how do i make pumpkin pie",
    "what's the currency conversion between rubles and pounds",
    "please set an alarm for mid day",
    "how many hours will it take to get to my destination",
    "so i made a fraudulent transaction",
    "tell lydia and laura where i am located",
    "i want you to talk more quickly",
    "what's the deal with my health care",
    # Reworded versions of several of the prompts above.
    "What's the exchange rate for rubles to pounds",
    "How long will it take to reach my destination",
    "I suspect a fraudulent transaction on my account",
    "Inform Lydia and Laura of my current location",
    "I'd like you to speak faster",
    "Can you provide information about my health care",
    "Give me the details on my health insurance",
    # Additional prompts.
    "What's the local time now",
    "Find a recipe for chocolate chip cookies",
    "Check my credit card balance",
    "Translate 'Hello' to French",
    "Recommend a good restaurant nearby",
]

# App title string.
title = "Text Intent Classification"

# Markdown text rendered at the top of the page via gr.Markdown.
description = (
    "\n"
    "# Text Intent Classification\n"
    "This demo uses a model to classify text into different intents or "
    "categories. Enter a text and see the classification result.\n"
)
# Build the page: a markdown header, an input/output row, and a button row.
# The window title comes from the `title` constant; previously it carried a
# copy-pasted "Question Answering" label.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        text_input = gr.Textbox(lines=3, label="Input text", placeholder="Please enter text")
        category_output = gr.Textbox(lines=3, label="Category")
    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")
    # "Random" fills the input box with an example sentence; the other button
    # runs the model on the current input and shows the predicted intent.
    random_button.click(random_sample, inputs=None, outputs=text_input)
    classifier_button.click(classifier, inputs=text_input, outputs=category_output)

if __name__ == "__main__":
    # NOTE(review): share=True requests a public share link when run locally.
    # Hugging Face Spaces reportedly does not support it — confirm it is wanted.
    demo.launch(share=True)