# Source: Hugging Face Space by xjlulu — commit "wahaha" (revision 8e2ee04).
# (The lines above were scraped web-page chrome; kept here as a comment so the
# file remains valid Python.)
import gradio as gr
from typing import Dict, List
import torch

# Disable cuDNN globally before any model code runs.
# NOTE(review): the reason is not documented in the source — presumably to
# avoid cuDNN availability/determinism issues in the hosting environment on
# CPU; confirm this is still required.
torch.backends.cudnn.enabled = False

import json
import pickle
from pathlib import Path

# Project-local modules: tokenizer vocabulary and the classifier definition.
from utils import Vocab
from model import SeqClassifier
# Model hyper-parameters — these must match the values used when the
# checkpoint was trained; the state-dict load below depends on them.
max_len = 128         # fixed sequence length inputs are encoded to (see collate_fn)
hidden_size = 256     # hidden state size of the sequence model
num_layers = 2        # number of stacked layers in the sequence model
dropout = 0.1         # dropout probability (inactive at inference after .eval())
bidirectional = True  # bidirectional sequence model
device = "cpu"        # inference device; everything below is moved/loaded to CPU
ckpt_dir = Path("./ckpt/intent/")    # directory holding the trained checkpoint
cache_dir = Path("./cache/intent/")  # directory holding vocab + label mapping
# Load the tokenizer vocabulary and the intent-to-index mapping that were
# produced at training time.
with open(cache_dir / "vocab.pkl", "rb") as cache_file:
    vocab: Vocab = pickle.load(cache_file)

intent2idx: Dict[str, int] = json.loads(
    (cache_dir / "intent2idx.json").read_text()
)

# Inverse mapping: class index -> intent label string.
__idx2label = {index: label for label, index in intent2idx.items()}
def idx2label(idx: int) -> str:
    """Map a predicted class index back to its intent label string.

    Raises KeyError if ``idx`` is not a known class index.
    """
    return __idx2label[idx]
# Placeholder embedding matrix used to size the model's embedding layer.
# (6491, 300) must match the vocabulary size and embedding dimension used at
# training time. torch.empty leaves the values uninitialized — presumably
# they are overwritten by the checkpoint's weights below; confirm that the
# checkpoint includes the embedding layer.
embeddings_size = (6491, 300)
embeddings = torch.empty(embeddings_size)
# BUG FIX: Tensor.to() is not in-place — the original discarded its result.
# A no-op while device == "cpu", but a real bug on any other device.
embeddings = embeddings.to(device)
# Assemble the classifier with the same hyper-parameters used at training
# time, then move it to the inference device.
best_model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(intent2idx),
).to(device)

# Restore the trained weights from disk.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoint files.
ckpt_path = ckpt_dir / "model_checkpoint.pth"
state = torch.load(ckpt_path, map_location=torch.device('cpu'))
best_model.load_state_dict(state['model_state_dict'])

# Inference only: disable dropout and other training-mode behavior.
best_model.eval()
# Processing function to convert text to embedding indices
def collate_fn(texts: str) -> torch.Tensor:
    """Tokenize a raw string and encode it as a tensor of vocabulary indices.

    Splits on whitespace, then uses the module-level ``vocab`` to map tokens
    to indices, padded/truncated to ``max_len`` (assuming ``encode_batch``
    honors ``to_len`` that way — behavior defined in utils.Vocab).
    Returns a 1-D integer tensor.
    """
    # FIX: annotation was `torch.tensor` (a factory function, not a type);
    # the correct type is `torch.Tensor`.
    tokens = texts.split()
    # encode_batch expects a batch of token lists; wrap one item and unwrap
    # the single result. ([[t for t in tokens]] was an identity copy.)
    encoded = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded)
# Classification function
def classify(text: str) -> str:
    """Run the intent classifier on raw input text.

    Returns a human-readable string of the form "Category:<intent>".
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: avoid building an autograd graph for the forward pass.
    with torch.no_grad():
        output = best_model(encoded_text)
    # argmax over the model's scores picks the most likely class index.
    # NOTE(review): encoded_text is 1-D (no batch dimension) — presumably
    # SeqClassifier accepts that; confirm against the model definition.
    predicted_class = torch.argmax(output).item()
    prediction = idx2label(predicted_class)
    return "Category:" + prediction
# Build the Gradio UI: one free-text input, label output, canned examples.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs="label",
    # NOTE(review): `interpretation` and `enable_queue` were removed in
    # Gradio 4.x; this call assumes a 3.x installation — confirm the pinned
    # gradio version before upgrading.
    interpretation="none",
    live=False,  # classify only on submit, not on every keystroke
    enable_queue=True,
    examples=[
        ["please set an alarm for mid day"],
        ["tell lydia and laura where i am located"],
        ["what's the deal with my health care"]
    ],
    title="Text Intent Classification Demo",
    description="This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.",
    theme="gradio/seafoam"
)
# Launch the app; share=True requests a public tunnel URL.
demo.launch(share=True)