Spaces:
Runtime error
Runtime error
import gradio as gr | |
import json | |
import pickle | |
from pathlib import Path | |
from utils import Vocab | |
from model import SeqTagger | |
from dataset import SeqTaggingClsDataset | |
from typing import Dict | |
import torch | |
# Turn off the cuDNN backend; this demo is configured for CPU (see `device`).
torch.backends.cudnn.enabled = False

# Module-level settings. `lr` and `batch_size` are defined here but are not
# read anywhere in the visible part of this script.
max_len = 256
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
lr = 1e-3
batch_size = 1
device = "cpu"

# Directories holding the model checkpoint and the preprocessed cache files.
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Vocabulary object, unpickled from the cache directory.
# NOTE(review): unpickling can execute arbitrary code — only load cache files
# that come from a trusted source.
vocab: Vocab = pickle.loads((cache_dir / "vocab.pkl").read_bytes())

# Tag name -> index mapping read from JSON, plus the inverse mapping used to
# turn predicted indices back into tag names.
tag_idx_path = cache_dir / "tag2idx.json"
with open(tag_idx_path) as tag_file:
    tag2idx: Dict[str, int] = json.load(tag_file)
idx2tag = dict(zip(tag2idx.values(), tag2idx.keys()))
def _idx2tag(idx: int):
    """Return the tag name stored in ``idx2tag`` for the index ``idx``.

    Raises:
        KeyError: If ``idx`` is not a key of ``idx2tag``.
    """
    tag = idx2tag[idx]
    return tag
# Dataset wrapper built from an empty sample dict; classify() only calls its
# collate_fn to encode the text typed by the user.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding matrix: allocated with a fixed shape but left
# uninitialised.
# NOTE(review): presumably the real values arrive with the checkpoint's state
# dict loaded below — confirm against SeqTagger.
shape = (4117, 300)
embeddings = torch.empty(*shape, device=device)

# Build the tagger from the module-level settings and place it on `device`.
model_kwargs = dict(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
)
best_model = SeqTagger(**model_kwargs)
best_model = best_model.to(device)

# Restore the weights saved under the "model_state_dict" key, mapping every
# stored tensor onto the CPU.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model_state_dict"]
best_model.load_state_dict(state_dict)

# Switch the model to evaluation mode before serving predictions.
best_model.eval()
def classify(text: str):
    """Predict slot tags for ``text`` and format them for ``gr.HighlightedText``.

    The text is split on whitespace, encoded with ``datasets.collate_fn``, run
    through ``best_model``, and every predicted tag index is mapped back to
    its tag name with ``_idx2tag``.

    Args:
        text: Raw sentence entered by the user.

    Returns:
        A list of ``(segment, label)`` tuples. Every token is followed by a
        ``(" ", None)`` separator, and tokens predicted as ``"O"`` carry
        ``None`` as their label. Empty or whitespace-only input yields an
        empty list without calling the model.
    """
    # Split once and reuse the result for both encoding and output pairing.
    tokens = text.split() if text else []
    if not tokens:
        # Nothing to tag; skip encoding and the model call entirely.
        return []

    # The token list is passed in its stringified form (e.g. "['a', 'b']").
    # NOTE(review): presumably this matches how collate_fn expects samples to
    # be stored — verify against SeqTaggingClsDataset.
    dic_text = {"tokens": [str(tokens)], "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(dic_text)

    # Keep only positions whose encoded tag is not -1.
    mask = encoded_data["encoded_tags"] != -1

    preds = []
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        for encoded_token in encoded_data["encoded_tokens"].to(device):
            # Add a leading batch dimension of size 1 for the model.
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            tag_ids = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(tag_id) for tag_id in tag_ids])

    # Pair tokens with predictions; zip stops at the shorter sequence, so a
    # length mismatch cannot raise IndexError.
    text_tags = []
    for token, tag in zip(tokens, preds[0]):
        text_tags.append((token, None if tag == "O" else tag))
        text_tags.append((" ", None))
    return text_tags
# Example sentences offered to the user in the interface.
example_sentences = [
    "i have three people for august seventh",
    "a table for 2 adults and 4 children please",
    "i have a booking tomorrow for chara conelly at 9pm",
    "me and 4 others will be there at 8:30pm",
    "probably malik belliard has done the booking and it is on in 10 days",
    "i want to book a table for me and my wife tonight at 6 p.m",
    "date 18th of december",
]

# Wire classify() between a text box and a highlighted-text output.
# NOTE(review): `interpretation` and `enable_queue` may not be accepted by
# every Gradio release — confirm against the version pinned for this app.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Please enter a text..."),
    outputs=gr.HighlightedText(),
    interpretation="none",
    live=False,
    enable_queue=True,
    examples=[[sentence] for sentence in example_sentences],
    title="Slot Tagging",
    description="This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.",
)

# Start the app, requesting a shareable link.
demo.launch(share=True)