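# slot_tagging/app.py — Gradio demo for the slot-tagging model.
# Provenance (from the Hugging Face file page): author xjlulu,
# commit fba58f1 "good run", 3.34 kB ("raw" / "history" / "blame" were page links).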
import gradio as gr
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqTagger
from dataset import SeqTaggingClsDataset
from typing import Dict
import torch
# Turn off cuDNN kernels. NOTE: this does not by itself force CPU execution;
# CPU inference comes from `device` below (and the checkpoint's map_location).
torch.backends.cudnn.enabled = False

# Hyperparameters. The architecture-defining values (hidden_size, num_layers,
# bidirectional) presumably must match the training run, otherwise the
# checkpoint's weights will not fit the model — confirm against training config.
max_len = 256          # maximum sequence length handed to the dataset
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
# NOTE(review): lr and batch_size are not used anywhere in this file —
# presumably leftovers from the training script; kept for compatibility.
lr = 1e-3
batch_size = 1
device = "cpu"         # all tensors and the model are placed on this device
# Model and data paths.
# NOTE(review): these are relative to the current working directory, not to
# this file — run the app from the repository root.
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Load the vocabulary produced at preprocessing time.
# SECURITY: pickle.load can execute arbitrary code; only load a vocab.pkl
# from a trusted source.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag-name -> class-index mapping and build its inverse, used to
# decode predicted indices back into tag names. JSON is read as UTF-8
# explicitly rather than relying on the platform's default encoding.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text(encoding="utf-8"))
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
def _idx2tag(idx: int):
    """Return the tag name for the class index ``idx``.

    Reads the module-level ``idx2tag`` table (the inverse of ``tag2idx``);
    an index missing from that table raises ``KeyError``.
    """
    tag_name = idx2tag[idx]
    return tag_name
# Dataset object; in this file it is only used for its collate_fn, which
# encodes raw text into model inputs. The empty first argument presumably
# means "no samples" — confirm against SeqTaggingClsDataset.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding matrix that only gives SeqTagger the right shape.
# NOTE(review): (4117, 300) is hard-coded and presumably (vocab size,
# embedding dim) of the training run; it must match the checkpoint exactly.
# Its values are expected to be overwritten by load_state_dict below (strict
# by default). Zeros are used instead of torch.empty so the model never
# computes with uninitialized memory if that assumption is ever wrong.
shape = (4117, 300)
embeddings = torch.zeros(shape, device=device)

# Build the tagger with the same hyperparameters used for training.
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
).to(device)

# Restore the trained weights. map_location remaps tensors that were saved on
# another device (e.g. a GPU) onto the configured device.
# SECURITY: torch.load unpickles the file; only load trusted checkpoints.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Evaluation mode: disables dropout for inference.
best_model.eval()
def classify(text: str):
    """Predict slot tags for a sentence, formatted for ``gr.HighlightedText``.

    The sentence is split on whitespace, encoded with the dataset's
    ``collate_fn`` and run through ``best_model``. Each predicted class index
    is mapped back to its tag name.

    Args:
        text: Raw input sentence. ``None`` or a blank string is accepted and
            produces no output.

    Returns:
        A list of ``(token, label)`` pairs, each token followed by a
        ``(" ", None)`` separator. ``label`` is the predicted tag, or ``None``
        for the outside tag ``"O"`` so that token is not highlighted. An empty
        list is returned when ``text`` contains no tokens.
    """
    tokens = (text or "").split()
    if not tokens:
        # Nothing to tag: skip encoding and inference on an empty sequence.
        return []

    # NOTE(review): the token list is passed to collate_fn as its string
    # representation (e.g. "['i', 'have']"), exactly as before. Presumably
    # collate_fn parses that form back into a list; confirm before changing.
    # No gold tags exist at inference time, hence "tags": [None].
    batch = {"tokens": [str(tokens)], "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(batch)

    # Keep only positions whose encoded tag is not -1 (presumably padding).
    mask = encoded_data['encoded_tags'] != -1

    preds = []
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        for encoded_token in encoded_data['encoded_tokens'].to(device):
            # Restore a leading batch dimension of 1 for this sequence.
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            # NOTE(review): argmax over dim=1 followed by a (seq_len,) boolean
            # mask assumes (seq_len, num_class) logits for a batch of one;
            # confirm against SeqTagger.forward. mask[0] is the single
            # input's mask.
            outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(output) for output in outputs])

    text_tags = []
    # zip stops at the shorter sequence, so a length mismatch between tokens
    # and predictions (e.g. presumably truncation at max_len) cannot raise
    # IndexError.
    for token, tag in zip(tokens, preds[0]):
        label = None if tag == "O" else tag
        text_tags.extend([(token, label), (" ", None)])
    return text_tags
# Gradio UI: a text box in, the sentence with highlighted slot spans out.
# Queuing is enabled with .queue(). The `enable_queue=` constructor argument
# is deprecated in Gradio 3 and removed in Gradio 4. `interpretation="none"`
# was dropped: it is not a documented value ("default", "shap" or a callable),
# and the parameter no longer exists in Gradio 4, so no interpretation is
# configured.
demo = gr.Interface(
    classify,
    gr.Textbox(placeholder="Please enter a text..."),
    gr.HighlightedText(),
    live=False,
    examples=[
        ["i have three people for august seventh"],
        ["a table for 2 adults and 4 children please"],
        ["i have a booking tomorrow for chara conelly at 9pm"],
        ["me and 4 others will be there at 8:30pm"],
        ["probably malik belliard has done the booking and it is on in 10 days"],
        ["i want to book a table for me and my wife tonight at 6 p.m"],
        ["date 18th of december"],
    ],
    title="Slot Tagging",
    description="This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.",
).queue()
# Launch the Gradio server only when run as a script, so importing this module
# (or running it via Gradio's reload mode) does not start a second server.
# share=True additionally requests a public share link.
if __name__ == "__main__":
    demo.launch(share=True)