# slot_tagging / app.py
# Author: xjlulu — commit e233148 (commit message: "~")
import gradio as gr
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqTagger
from dataset import SeqTaggingClsDataset
from typing import Dict
import torch
from gradio.components import Textbox
import random
# Switch cuDNN off globally. Note that this flag by itself does not force CPU
# execution; that is done by moving the model and tensors to `device` below.
torch.backends.cudnn.enabled = False

# ----- Runtime / hyperparameters ---------------------------------------------
# NOTE(review): hidden_size, num_layers and bidirectional presumably have to
# match the values the checkpoint was trained with — confirm before changing.
device = "cpu"          # target device for the model and every input tensor
max_len = 256           # max sequence length handed to the dataset encoder
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
lr = 1e-3               # training-time setting; not used by the inference code here
batch_size = 1          # training-time setting; not used by the inference code here

# ----- Artifact locations ----------------------------------------------------
ckpt_dir = Path("./ckpt/slot/")    # holds slot_checkpoint.pth
cache_dir = Path("./cache/slot/")  # holds vocab.pkl and tag2idx.json
# Load the vocabulary object produced during preprocessing.
# NOTE(review): pickle.load can execute arbitrary code; this is only acceptable
# because vocab.pkl is presumably a local, trusted artifact — never load
# user-supplied files this way.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag -> class-index mapping and build its inverse, which is used to
# decode predicted class indices back into tag names.
tag_idx_path = cache_dir / "tag2idx.json"
# JSON is UTF-8 by specification; read it explicitly as UTF-8 so decoding does
# not depend on the platform's locale encoding.
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text(encoding="utf-8"))
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
def _idx2tag(idx: int):
    """Translate a predicted class index into its tag name.

    Looks *idx* up in the module-level ``idx2tag`` mapping; an index with no
    entry raises ``KeyError``.
    """
    tag = idx2tag[idx]
    return tag
# Dataset wrapper built with an empty data dict; it is used here for its
# collate_fn, which encodes raw sentences at inference time.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding matrix of shape (4117, 300). torch.empty leaves the
# values uninitialized; they are presumably overwritten when the checkpoint's
# state_dict is loaded below — TODO confirm SeqTagger registers the embedding
# weights in its state_dict.
shape = (4117, 300)
embeddings = torch.empty(shape).to(device)

# Tagger with the hyperparameters declared above and one output class per tag.
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
).to(device)

# Restore the trained weights, mapping every stored tensor onto `device`.
# torch.load unpickles the file, so only trusted checkpoints belong here.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"
checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
best_model.load_state_dict(checkpoint["model_state_dict"])

# Inference mode (e.g. dropout layers become no-ops).
best_model.eval()
def tagging(text: str):
    """Predict an IOB slot tag for each whitespace-separated token of *text*.

    The sentence is encoded with the module-level ``datasets.collate_fn`` and
    run through ``best_model``. The arg-max class at every unmasked position
    is then mapped back to its tag name with ``_idx2tag``.

    Args:
        text: Raw sentence entered by the user. ``None``, empty or
            whitespace-only input yields an empty result without invoking
            the model.

    Returns:
        A list of ``(token, tag)`` pairs, each followed by a ``(" ", None)``
        separator, which is the format ``gr.HighlightedText`` renders.
    """
    # Split once and reuse; the original re-split the text on every iteration.
    tokens = text.split() if text else []
    if not tokens:
        # Nothing to tag: don't feed an empty sequence to the model.
        return []

    # One-sample batch. The token list is passed as its string repr, exactly
    # as before (presumably collate_fn parses it back — TODO confirm against
    # SeqTaggingClsDataset.collate_fn).
    batch = {"tokens": [str(tokens)], "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(batch)

    # Positions whose encoded tag is -1 are dropped from the predictions.
    mask = encoded_data["encoded_tags"] != -1

    preds = []
    # Pure inference: no_grad skips building an autograd graph per request.
    with torch.no_grad():
        for encoded_token in encoded_data["encoded_tokens"].to(device):
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(output) for output in outputs])

    text_tags = []
    # zip pairs tokens and tags positionally and stops at the shorter one, so
    # a length mismatch between them can no longer raise IndexError.
    for token, tag in zip(tokens, preds[0]):
        text_tags.extend([(token, tag), (" ", None)])
    return text_tags
# Canned sentences offered by the "Random" button (picked by random_sample).
examples = [
    "i have three people for august seventh",
    "a table for 2 adults and 4 children please",
    "i have a booking tomorrow for chara conelly at 9pm",
    "me and 4 others will be there at 8:30pm",
    "probably malik belliard has done the booking and it is on in 10 days",
    "i want to book a table for me and my wife tonight at 6 p.m",
    "date 18th of december",
    "The concert is on September fifteenth",
    "I need a reservation for a party of eight on Sunday",
    "Her birthday is on May twenty-third",
    "We have a meeting at ten a.m. tomorrow",
    "The conference starts at eight o'clock in the morning",
    "He booked a flight for February seventh",
    "There is an event on the twenty-ninth of June",
    "Please reserve a table for two for this evening",
    "The project deadline is on March fourth",
    "We'll have a gathering on the first of July",
]
def random_sample():
    """Return one sentence from ``examples``, chosen uniformly at random.

    Used as the "Random" button callback to pre-fill the input box.
    """
    # random.choice is the idiomatic form of randint(0, len - 1) + indexing.
    return random.choice(examples)
# Page title (passed to gr.Blocks) and the Markdown header rendered at the top
# of the demo via gr.Markdown.
title = "Slot Tagging"
description = """
# Slot Tagging
This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.
"""
# Gradio UI: an input box for the sentence, a highlighted view of the predicted
# IOB tags, and buttons to load a random example or run the tagger.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = Textbox(lines=3, label="Context", placeholder="Please enter text")
        # `lines` is a Textbox option, not a HighlightedText one, so it is not
        # passed here.
        T_output = gr.HighlightedText(label="IOB Tagging")
    with gr.Row():
        random_button = gr.Button("Random")
        tagging_button = gr.Button("Tagging")
    # "Random" fills the input box with an example; "Tagging" runs the model
    # on the input box's contents and shows the result in T_output.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    tagging_button.click(tagging, inputs=C_input, outputs=T_output)

demo.launch()