"""Gradio demo that serves a trained intent-classification model.

At import time this script loads the cached vocabulary and intent mapping,
rebuilds ``SeqClassifier``, restores its weights from a checkpoint, and then
launches a Gradio interface whose handler is :func:`classify`.
"""
import gradio as gr
from typing import Dict, List

import torch

torch.backends.cudnn.enabled = False

import json
import pickle
from pathlib import Path

from utils import Vocab
from model import SeqClassifier

# Set model parameters
max_len = 128
hidden_size = 256
num_layers = 2
dropout = 0.1
bidirectional = True

device = "cpu"
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Load vocabulary and intent index mapping.
# SECURITY: pickle.load can execute arbitrary code; only load a vocab.pkl
# that was produced by a trusted preprocessing run.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())

# Reverse mapping: class index -> intent label.
__idx2label = {idx: intent for intent, idx in intent2idx.items()}


def idx2label(idx: int):
    """Return the intent label stored for class index ``idx``.

    Raises:
        KeyError: if ``idx`` is not a value of ``intent2idx``.
    """
    return __idx2label[idx]


# Set embedding layer size.
# The tensor is only a shape placeholder: its (uninitialised) values are
# handed to the model constructor and the real weights are expected to come
# from the checkpoint loaded below.
# NOTE(review): (6491, 300) is hard-coded and must match the checkpoint --
# confirm against the training configuration.
embeddings_size = (6491, 300)
embeddings = torch.empty(embeddings_size, device=device)

# Load the best model
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Define the path to the checkpoint file
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the model's weights onto the configured device
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
best_model.eval()


# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Encode a raw sentence as a tensor of vocabulary indices.

    The sentence is split on whitespace and encoded through
    ``vocab.encode_batch`` as a batch of one; the single encoded row is
    returned as a tensor.

    Args:
        texts: The raw input sentence.

    Returns:
        A tensor built from the encoded token ids of the sentence.
    """
    tokens = texts.split()
    # `to_len=max_len` is forwarded to the encoder -- presumably it pads or
    # truncates to that length; confirm in utils.Vocab.
    encoded_row = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded_row)


# Classification function
def classify(text):
    """Predict the intent of ``text`` and format it for display.

    Args:
        text: The sentence typed into the Gradio textbox.

    Returns:
        The string ``"Category:"`` followed by the predicted intent label.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # NOTE(review): the input is passed without an explicit batch
        # dimension -- confirm this is what SeqClassifier.forward expects.
        output = best_model(encoded_text)
    predicted_class = torch.argmax(output).item()
    prediction = idx2label(predicted_class)
    return "Category:" + prediction


# Create a Gradio interface
# NOTE(review): `interpretation` and `enable_queue` look like older-Gradio
# Interface arguments -- confirm they are accepted by the pinned version.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs="label",
    interpretation="none",
    live=False,
    enable_queue=True,
    examples=[
        ["please set an alarm for mid day"],
        ["tell lydia and laura where i am located"],
        ["what's the deal with my health care"],
    ],
    title="Text Intent Classification",
    description=(
        "This demo uses a model to classify text into different intents or "
        "categories. Enter a text and see the classification result."
    ),
)

# Launch the Gradio interface
demo.launch(share=True)