Spaces:
Runtime error
Runtime error
File size: 4,252 Bytes
702e96a fba58f1 0805ae7 702e96a fba58f1 702e96a fba58f1 0805ae7 fba58f1 0805ae7 fba58f1 0805ae7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqTagger
from dataset import SeqTaggingClsDataset
from typing import Dict
import torch
from gradio.components import Textbox
import random
# cuDNN is a GPU backend; it is switched off because this demo runs on CPU
# (see `device` below).
torch.backends.cudnn.enabled = False

# Hyperparameters.
# NOTE(review): presumably these must match the values the checkpoint was
# trained with -- confirm against the training script. `lr` and `batch_size`
# are not referenced by the inference code visible in this file.
max_len = 256
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
lr = 1e-3
batch_size = 1
device = "cpu"

# Model and data paths (relative to the working directory).
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Load the vocabulary.
# SECURITY: pickle.load can execute arbitrary code embedded in the file, so
# vocab.pkl must only ever come from a trusted source.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag -> index mapping and build the inverse index -> tag mapping.
# The encoding is given explicitly so decoding does not depend on the
# platform's default locale.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text(encoding="utf-8"))
idx2tag = {idx: tag for tag, idx in tag2idx.items()}
def _idx2tag(idx: int):
    """Return the tag stored in the module-level ``idx2tag`` map for ``idx``.

    Raises:
        KeyError: if ``idx`` is not a key of ``idx2tag``.
    """
    tag = idx2tag[idx]
    return tag
# Dataset object built with an empty sample mapping; the code in this file
# only calls its `collate_fn`.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Location of the trained weights.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"

# Uninitialized tensor of the fixed shape handed to the model as `embeddings`.
# NOTE(review): presumably its values are overwritten by load_state_dict
# below -- confirm that the checkpoint contains the embedding weights.
shape = (4117, 300)
embeddings = torch.empty(shape, device=device)

# Build the tagger and place it on the target device.
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
)
best_model = best_model.to(device)

# Restore the trained weights, mapping every tensor onto the CPU.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
best_model.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode.
best_model.eval()
def tagging(text: str):
    """Predict a slot tag for each whitespace-separated token of ``text``.

    Args:
        text: Sentence entered by the user.

    Returns:
        A list of tuples alternating ``(token, tag)`` and the separator
        ``(" ", None)``, which is what the click handler feeds to the
        highlighted-text output. An empty list is returned when ``text`` is
        empty, whitespace-only or ``None``.
    """
    # Split once and reuse; the original re-split the text for every tag.
    tokens = (text or "").split()
    if not tokens:
        # Nothing to tag: skip collation and the model call entirely.
        return []

    # NOTE(review): the token list is passed as its string repr, not as a
    # list. Kept as-is because collate_fn is defined outside this file --
    # confirm this is the format it expects.
    str_text = [str(tokens)]
    dic_text = {"tokens": str_text, "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(dic_text)

    # Positions whose encoded tag equals -1 are dropped from the predictions.
    mask = encoded_data['encoded_tags'] != -1

    preds = []
    # Inference only: no_grad avoids building autograd state per request.
    with torch.no_grad():
        for encoded_token in encoded_data['encoded_tokens'].to(device):
            # Add a leading batch dimension of size 1.
            encoded_token = encoded_token.reshape(1, encoded_token.shape[0])
            outputs = best_model(encoded_token)
            outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()
            preds.append([_idx2tag(output) for output in outputs])

    # Pair each predicted tag with the token at the same position.
    text_tags = []
    for i, tag in enumerate(preds[0]):
        text_tags.extend([(tokens[i], tag), (" ", None)])
    return text_tags
# Sample sentences offered by the "Random" button.
examples = [
    "i have three people for august seventh",
    "a table for 2 adults and 4 children please",
    "i have a booking tomorrow for chara conelly at 9pm",
    "me and 4 others will be there at 8:30pm",
    "probably malik belliard has done the booking and it is on in 10 days",
    "i want to book a table for me and my wife tonight at 6 p.m",
    "date 18th of december",
    "The concert is on September fifteenth",
    "I need a reservation for a party of eight on Sunday",
    "Her birthday is on May twenty-third",
    "We have a meeting at ten a.m. tomorrow",
    "The conference starts at eight o'clock in the morning",
    "He booked a flight for February seventh",
    "There is an event on the twenty-ninth of June",
    "Please reserve a table for two for this evening",
    "The project deadline is on March fourth",
    "We'll have a gathering on the first of July",
]
def random_sample():
    """Return one sentence from ``examples``, chosen uniformly at random."""
    return examples[random.randrange(len(examples))]
# Markdown shown at the top of the page.
description = """
# Slot Tagging
This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.
"""
title = "Slot Tagging"

# Page layout: description, an input/output row, then a row of buttons.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)
    with gr.Row():
        C_input = Textbox(lines=3, label="Context", placeholder="Please enter a text...")
        # `lines` is a Textbox option; HighlightedText does not define it, and
        # Gradio releases that reject unknown keyword arguments fail on it.
        T_output = gr.HighlightedText(label="IOB Tagging")
    with gr.Row():
        random_button = gr.Button("Random")
        tagging_button = gr.Button("Tagging")
    # Event listeners are registered inside the Blocks context.
    # "Random" fills the input box with a sample sentence; "Tagging" runs the
    # model on the input and renders the tagged tokens.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    tagging_button.click(tagging, inputs=C_input, outputs=T_output)

# Start the web server.
demo.launch()