# Spaces status: Runtime error
import gradio as gr | |
from typing import Dict, List | |
import torch | |
torch.backends.cudnn.enabled = False | |
import json | |
import pickle | |
from pathlib import Path | |
from utils import Vocab | |
from model import SeqClassifier | |
# Inference settings: padded sequence length and recurrent-network options.
max_len = 128
hidden_size, num_layers = 256, 2
dropout = 0.1
bidirectional = True
device = "cpu"

# Directories holding the trained checkpoint and the preprocessing cache.
ckpt_dir = Path("./ckpt/intent/")
cache_dir = Path("./cache/intent/")

# Restore the pickled vocabulary object from the cache directory.
# NOTE(review): unpickling runs arbitrary code; only load trusted files.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Read the intent-name -> class-index table, then build the reverse table
# (class-index -> intent-name) used to turn predictions back into text.
intent_idx_path = cache_dir / "intent2idx.json"
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
__idx2label = dict(zip(intent2idx.values(), intent2idx.keys()))
def idx2label(idx: int) -> str:
    """Translate a class index into its intent name.

    The lookup uses the module-level ``__idx2label`` table; an index that is
    not in the table raises ``KeyError``.
    """
    label = __idx2label[idx]
    return label
# Shape of the embedding matrix handed to the model constructor.
# NOTE(review): presumably (vocabulary size, embedding dimension) of the
# trained model — confirm against the checkpoint.
embeddings_size = (6491, 300)

# torch.empty leaves the values uninitialised, so this tensor is only a
# placeholder. NOTE(review): presumably the real weights are restored by
# load_state_dict below — confirm.
# Tensor.to() is not in-place and its result was previously discarded, so the
# tensor is now allocated directly on the target device.
embeddings = torch.empty(embeddings_size, device=device)

# Build the classifier with the configured hyper-parameters.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Path to the checkpoint file.
ckpt_path = ckpt_dir / "intent_checkpoint.pth"

# Load the stored weights onto the configured device (kept in sync with
# `device` instead of a separately hard-coded 'cpu').
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
best_model.load_state_dict(checkpoint["model_state_dict"])

# Switch the model to evaluation mode.
best_model.eval()
# Processing function to convert text to embedding indices | |
def collate_fn(texts: str) -> torch.Tensor:
    """Convert a raw sentence into a tensor of vocabulary indices.

    Args:
        texts: Input sentence; it is tokenised by splitting on whitespace.

    Returns:
        A tensor built from the first (and only) row that
        ``vocab.encode_batch`` returns when called with ``to_len=max_len``.
    """
    tokens = texts.split()
    # encode_batch takes a batch of token lists, so wrap the single sentence
    # in a one-element batch and unwrap the single result again.
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
# Classification function | |
def classify(text: str) -> str:
    """Predict the intent of ``text`` and return it as a display string.

    Args:
        text: Sentence entered by the user.

    Returns:
        The string ``"Category:"`` followed by the predicted intent name.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: disable gradient tracking so no autograd graph is built.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax over the model output gives the index of the top-scoring class.
    predicted_class = torch.argmax(output).item()
    return "Category:" + idx2label(predicted_class)
# Build the Gradio UI: one text box as input, a label component as output,
# with `classify` as the callback and three sample sentences.
# NOTE(review): `interpretation` and `enable_queue` appear to have been removed
# from gr.Interface in newer Gradio releases — confirm the pinned version.
demo = gr.Interface(
    title="Text Intent Classification",
    description="This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.",
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs="label",
    examples=[
        ["please set an alarm for mid day"],
        ["tell lydia and laura where i am located"],
        ["what's the deal with my health care"],
    ],
    interpretation="none",
    live=False,
    enable_queue=True,
)

# Start the web server and request a public share link.
demo.launch(share=True)