import json
import pickle
import random
from pathlib import Path
from typing import Dict

import gradio as gr
import torch
from gradio.components import Textbox

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab

# Disable cuDNN; the model itself is placed on the CPU through `device` below
torch.backends.cudnn.enabled = False

# Define hyperparameters (these must match the ones used to train the checkpoint)
max_len = 256
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
lr = 1e-3
batch_size = 1
device = "cpu"

# Model and data paths
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Load the vocabulary
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag mapping and build the reverse (index -> tag) mapping
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())
idx2tag = {idx: tag for tag, idx in tag2idx.items()}


def _idx2tag(idx: int) -> str:
    return idx2tag[idx]


# Create an empty dataset; only its collate_fn is used, to encode user input
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder (uninitialized) embedding matrix. Its shape must match the
# embedding layer stored in the checkpoint; its values are expected to be
# overwritten when the checkpoint weights are loaded below.
shape = (4117, 300)
embeddings = torch.empty(shape).to(device)

# Create the model
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
).to(device)

# Define the path to the model checkpoint
ckpt_path = ckpt_dir / "slot_checkpoint.pth"

# Load the model weights, mapping all tensors onto the CPU
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
best_model.load_state_dict(checkpoint["model_state_dict"])

# Set the model to evaluation mode (disables dropout)
best_model.eval()


def tagging(text: str):
    # Tokenize the text on whitespace; the token list is passed to
    # collate_fn in its string form, as in the original script
    tokens = text.split()
    str_text = [str(tokens)]
    dic_text = {"tokens": str_text, "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(dic_text)

    preds = []
    # Positions whose encoded tag is -1 (padding) are masked out
    mask = encoded_data["encoded_tags"] != -1

    # Use the trained model to predict each data point (no gradients needed)
    with torch.no_grad():
        for encoded_token in encoded_data["encoded_tokens"].to(device):
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(output) for output in outputs])

    # Pair each word with its predicted tag, separated by unlabeled spaces,
    # in the (text, label) format expected by gr.HighlightedText
    text_tags = []
    for word, tag in zip(tokens, preds[0]):
        text_tags.extend([(word, tag), (" ", None)])
    return text_tags


examples = [
    "i have three people for august seventh",
    "a table for 2 adults and 4 children please",
    "i have a booking tomorrow for chara conelly at 9pm",
    "me and 4 others will be there at 8:30pm",
    "probably malik belliard has done the booking and it is on in 10 days",
    "i want to book a table for me and my wife tonight at 6 p.m",
    "date 18th of december",
    "The concert is on September fifteenth",
    "I need a reservation for a party of eight on Sunday",
    "Her birthday is on May twenty-third",
    "We have a meeting at ten a.m. tomorrow",
    "The conference starts at eight o'clock in the morning",
    "He booked a flight for February seventh",
    "There is an event on the twenty-ninth of June",
    "Please reserve a table for two for this evening",
    "The project deadline is on March fourth",
    "We'll have a gathering on the first of July",
]


def random_sample() -> str:
    # Pick one example sentence uniformly at random
    return random.choice(examples)


description = """
# Slot Tagging
This is a demo for slot tagging. Enter a sentence, and the model will predict and highlight its slots.
"""
title = "Slot Tagging"

with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = Textbox(lines=3, label="Context", placeholder="Please enter a text...")
        # HighlightedText has no `lines` argument, so it is not passed here
        T_output = gr.HighlightedText(label="IOB Tagging")
    with gr.Row():
        random_button = gr.Button("Random")
        tagging_button = gr.Button("Tagging")
    random_button.click(random_sample, inputs=None, outputs=C_input)
    tagging_button.click(tagging, inputs=C_input, outputs=T_output)

demo.launch()