"""Gradio demo that classifies free-form text into intent categories.

Loads a trained ``SeqClassifier`` checkpoint together with its vocabulary
(``vocab.pkl``) and intent-index mapping (``intent2idx.json``), then serves a
small web UI: "Random" fills the input box with a canned example sentence and
"classifier" runs the model on the input and shows the predicted intent.
"""
import json
import pickle
import random
import re
from pathlib import Path
from typing import Dict, List

import gradio as gr
import torch
from gradio.components import Textbox

from model import SeqClassifier
from utils import Vocab

# cuDNN is disabled globally (kept from the original setup).
torch.backends.cudnn.enabled = False

# Model hyper-parameters. The architecture values presumably must match the
# configuration the checkpoint below was trained with.
max_len = 128
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True
device = "cpu"

ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Load vocabulary and intent index mapping.
# NOTE: pickle.load can execute arbitrary code stored in the file; only use a
# vocab.pkl produced by this project's own preprocessing.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))
# Reverse mapping: class index -> intent name.
__idx2label = {idx: intent for intent, idx in intent2idx.items()}


def idx2label(idx: int) -> str:
    """Return the intent name for class index *idx*.

    Raises:
        KeyError: if *idx* is not present in ``intent2idx.json``.
    """
    return __idx2label[idx]


# Placeholder embedding matrix, presumably (vocab_size, embed_dim). Its values
# are uninitialized; presumably they are replaced by the checkpoint weights in
# load_state_dict below. NOTE(review): the shape is hard-coded; confirm it
# matches the vocabulary/embeddings used at training time.
embeddings_size = (5621, 300)
embeddings = torch.empty(embeddings_size)

# Build the model and load the trained weights.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

ckpt_path = ckpt_dir / "intent_checkpoint.pth"
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
best_model.load_state_dict(checkpoint["model_state_dict"])

# Evaluation mode (disables dropout etc.).
best_model.eval()

# Tokenizer: runs of word characters, or single non-whitespace punctuation marks.
_TOKEN_RE = re.compile(r"\w+|[^\w\s]")


def collate_fn(texts: str) -> torch.Tensor:
    """Tokenize one input string and encode it as a tensor of vocab indices.

    Args:
        texts: Raw input text (a single string, despite the plural name).

    Returns:
        ``torch.tensor`` of element 0 of ``vocab.encode_batch`` applied to a
        one-item batch padded/truncated with ``to_len=max_len``.
    """
    tokens = _TOKEN_RE.findall(texts)
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)


def classifier(text: str) -> str:
    """Predict the intent label for *text* with the loaded model."""
    # NOTE(review): this passes an unbatched index tensor to the model; confirm
    # SeqClassifier accepts that shape.
    encoded_text = collate_fn(text).to(device)
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax over the flattened output; valid while one example is scored per call.
    predicted_class = torch.argmax(output).item()
    return idx2label(predicted_class)


def random_sample() -> str:
    """Return a uniformly random sentence from ``examples``."""
    return random.choice(examples)


examples = [
    "what are some fun things i can partake in in atlanta",
    "how do i make pumpkin pie",
    "what's the currency conversion between rubles and pounds",
    "please set an alarm for mid day",
    "how many hours will it take to get to my destination",
    "so i made a fraudulent transaction",
    "tell lydia and laura where i am located",
    "i want you to talk more quickly",
    "what's the deal with my health care",
    "What's the exchange rate for rubles to pounds",
    "How long will it take to reach my destination",
    "I suspect a fraudulent transaction on my account",
    "Inform Lydia and Laura of my current location",
    "I'd like you to speak faster",
    "Can you provide information about my health care",
    "Give me the details on my health insurance",
    "What's the local time now",
    "Find a recipe for chocolate chip cookies",
    "Check my credit card balance",
    "Translate 'Hello' to French",
    "Recommend a good restaurant nearby",
]

title = "Text Intent Classification"
description = """
# Text Intent Classification
This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.
"""

with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = gr.Textbox(lines=3, label="Context paragraph", placeholder="Please enter text")
        A_output = Textbox(lines=3, label="Category")
    with gr.Row():
        random_button = gr.Button("Random")
        classifier_button = gr.Button("classifier")
    # "Random" fills the input box with an example sentence;
    # "classifier" runs the model on the input box and shows the label.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    classifier_button.click(classifier, inputs=C_input, outputs=A_output)

demo.launch(share=True)