# Captured from a Hugging Face Space page; Space status at capture time: "Runtime error".
import json
import pickle
import random
from pathlib import Path
from typing import Dict

import gradio as gr
import torch
from gradio.components import Textbox

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab
# cuDNN is a GPU backend and is not needed for this CPU-only demo, so switch it
# off. Note that it is the `device` setting below, not this flag, that places
# the model on the CPU.
torch.backends.cudnn.enabled = False

# Hyperparameters. The model-shape values are presumably the ones the
# checkpoint was trained with -- confirm before changing them.
# `lr` and `batch_size` are not used by the code visible in this file.
max_len = 256
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
lr = 1e-3
batch_size = 1
device = "cpu"

# Model and data paths
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Load the vocabulary.
# NOTE(security): unpickling can execute arbitrary code; only load a
# vocab.pkl that was produced by this project.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag mapping (tag name -> class index) and build its inverse.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default encoding.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text(encoding="utf-8"))
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
def _idx2tag(idx: int) -> str:
    """Return the tag name stored in ``idx2tag`` for class index *idx*.

    Raises:
        KeyError: if *idx* is not a key of ``idx2tag``.
    """
    return idx2tag[idx]
# Create the dataset object with no samples; in this file it is only used for
# its collate_fn, which encodes the text entered in the UI.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding matrix: torch.empty allocates the tensor without
# initializing its values. NOTE(review): this presumably relies on the
# checkpoint loaded below overwriting the embedding weights, and (4117, 300)
# presumably matches the trained vocabulary size and embedding width --
# confirm against SeqTagger and the checkpoint.
shape = (4117, 300)
embeddings = torch.empty(shape).to(device)

# Create the model
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
).to(device)

# Define the path to the model checkpoint
ckpt_path = ckpt_dir / "slot_checkpoint.pth"

# Load the model weights. map_location reuses the module-wide `device`
# setting so the target device is configured in exactly one place.
checkpoint = torch.load(ckpt_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# Set the model to evaluation mode
best_model.eval()
def tagging(text: str) -> list:
    """Predict a slot tag for every whitespace-separated word of *text*.

    Args:
        text: Raw sentence entered by the user.

    Returns:
        A list alternating ``(word, tag)`` and ``(" ", None)`` tuples, i.e.
        each tagged word is followed by an untagged space separator.
        An empty list is returned when *text* contains no words.
    """
    # Split once and reuse the result everywhere below.
    words = text.split()
    if not words:
        # Nothing to tag; skip encoding and the model call entirely.
        return []

    # NOTE(review): the tokens are handed over as the *string form* of the
    # word list (e.g. "['a', 'b']"), and the batch is a dict of lists.
    # This is kept exactly as in the original because collate_fn is not
    # visible here -- confirm this is the format it expects.
    str_text = [str(words)]
    dic_text = {"tokens": str_text, "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(dic_text)

    # Keep only positions whose encoded tag is not -1 (presumably the
    # padding marker used by collate_fn -- confirm).
    mask = encoded_data['encoded_tags'] != -1

    preds = []
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        for encoded_token in encoded_data['encoded_tokens'].to(device):
            # Add a leading batch dimension of size 1.
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(output) for output in outputs])

    # Pair each word with its predicted tag. zip stops at the shorter
    # sequence, so a length mismatch cannot raise IndexError.
    text_tags = []
    for word, tag in zip(words, preds[0]):
        text_tags.extend([(word, tag), (" ", None)])
    return text_tags
# Sample sentences offered by the "Random" button.
examples = [
    "i have three people for august seventh",
    "a table for 2 adults and 4 children please",
    "i have a booking tomorrow for chara conelly at 9pm",
    "me and 4 others will be there at 8:30pm",
    "probably malik belliard has done the booking and it is on in 10 days",
    "i want to book a table for me and my wife tonight at 6 p.m",
    "date 18th of december",
    "The concert is on September fifteenth",
    "I need a reservation for a party of eight on Sunday",
    "Her birthday is on May twenty-third",
    "We have a meeting at ten a.m. tomorrow",
    "The conference starts at eight o'clock in the morning",
    "He booked a flight for February seventh",
    "There is an event on the twenty-ninth of June",
    "Please reserve a table for two for this evening",
    "The project deadline is on March fourth",
    "We'll have a gathering on the first of July",
]
def random_sample() -> str:
    """Return one sentence picked uniformly at random from ``examples``."""
    return random.choice(examples)
# Markdown shown at the top of the page.
description = """
# Slot Tagging
This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.
"""

# Title passed to the Blocks app.
title = "Slot Tagging"
# Build the UI: an input box next to the highlighted output, with two buttons
# underneath.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = Textbox(lines=3, label="Context", placeholder="Please enter a text...")
        # HighlightedText defines no `lines` parameter; passing it raises
        # TypeError on Gradio versions whose components do not accept
        # arbitrary keyword arguments, so it is omitted here.
        T_output = gr.HighlightedText(label="IOB Tagging")
    with gr.Row():
        random_button = gr.Button("Random")
        tagging_button = gr.Button("Tagging")
    # "Random" fills the input with a sample sentence; "Tagging" runs the
    # model on the current input and shows the tagged words.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    tagging_button.click(tagging, inputs=C_input, outputs=T_output)

demo.launch()