Spaces:
Runtime error
Runtime error
File size: 3,336 Bytes
702e96a fba58f1 702e96a fba58f1 702e96a fba58f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import gradio as gr
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqTagger
from dataset import SeqTaggingClsDataset
from typing import Dict
import torch
# cuDNN is a GPU-only backend, so disabling it does not decide where the model
# runs; CPU execution is selected by `device` below.
torch.backends.cudnn.enabled = False

# Model hyperparameters. The architecture values presumably must match the
# configuration the checkpoint was trained with, because load_state_dict
# (strict by default) rejects missing/unexpected or mismatched parameters.
max_len = 256
hidden_size = 500
num_layers = 2
dropout = 0.2
bidirectional = True
# NOTE(review): `lr` and `batch_size` are not referenced by the inference code
# in this file; they appear to be leftovers from the training configuration.
lr = 1e-3
batch_size = 1
device = "cpu"

# Locations of the trained checkpoint and the preprocessing cache
# (vocabulary pickle and tag-index JSON), relative to the working directory.
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")
# Token vocabulary cached as a pickle next to the tag mapping.
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# presumably acceptable only because vocab.pkl is a trusted local artifact.
vocab_path = cache_dir / "vocab.pkl"
with vocab_path.open("rb") as f:
    vocab: Vocab = pickle.load(f)

# Slot-tag -> class-index mapping, plus its inverse for decoding predictions.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())
idx2tag = {index: name for name, index in tag2idx.items()}
def _idx2tag(idx: int):
    """Translate a predicted class index into its slot-tag string.

    Raises KeyError when *idx* is not present in ``idx2tag``.
    """
    tag = idx2tag[idx]
    return tag
# The dataset holds no samples ({}); it exists so that classify() can use its
# collate_fn to encode and pad the user's input.
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding matrix. torch.empty leaves the values uninitialised;
# presumably SeqTagger keeps the embedding weights in its state_dict so that
# load_state_dict below overwrites them - confirm in model.py.
# NOTE(review): the hard-coded shape must match the checkpoint's embedding.
shape = (4117, 300)
embeddings = torch.empty(shape, device=device)

best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx),
).to(device)

# Restore the trained weights (mapped onto the CPU regardless of the device
# they were saved from), then switch to evaluation mode, which affects
# layers such as dropout.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"
checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
best_model.load_state_dict(checkpoint["model_state_dict"])
best_model.eval()
def classify(text: str):
    """Predict slot tags for *text* and format them for ``gr.HighlightedText``.

    Args:
        text: Raw user input; it is tokenised with ``str.split()``.

    Returns:
        A list of ``(string, label)`` pairs: every whitespace token paired
        with its predicted tag (``None`` when the tag is ``"O"``), each
        followed by a ``(" ", None)`` separator. Empty or whitespace-only
        input yields ``[]``.
    """
    tokens = text.split()
    if not tokens:
        # Nothing to tag; skip the model entirely.
        return []

    # NOTE(review): the token list is passed as its string repr
    # (e.g. "['i', 'have']"); presumably this matches how
    # SeqTaggingClsDataset.collate_fn parses the "tokens" field - confirm.
    batch = {"tokens": [str(tokens)], "tags": [None], "id": ["text-0"]}
    encoded_data = datasets.collate_fn(batch)

    # Keep only positions whose encoded tag is not -1 (presumably padding).
    mask = encoded_data["encoded_tags"] != -1

    # The batch holds exactly one sample, so only its first row is needed.
    encoded_token = encoded_data["encoded_tokens"][0].to(device)
    encoded_token = encoded_token.reshape(1, encoded_token.shape[0])

    # Inference only: don't build an autograd graph on every request.
    with torch.no_grad():
        outputs = best_model(encoded_token)
    pred_ids = torch.argmax(outputs, dim=1)[mask[0]].tolist()
    tags = [_idx2tag(pred_id) for pred_id in pred_ids]

    text_tags = []
    for i, tag in enumerate(tags):
        # "O" means "outside any slot", so it is shown without a highlight.
        label = None if tag == "O" else tag
        text_tags.extend([(tokens[i], label), (" ", None)])
    return text_tags
# Gradio UI: free-text input, predicted slots shown as highlighted text.
# `interpretation="none"` was dropped: it is not the "no interpretation"
# value (that is None, the default) and the argument is removed in newer
# Gradio. `enable_queue` is deprecated in favour of `demo.queue()` below.
demo = gr.Interface(
    classify,
    gr.Textbox(placeholder="Please enter a text..."),
    gr.HighlightedText(),
    live=False,
    examples=[
        ["i have three people for august seventh"],
        ["a table for 2 adults and 4 children please"],
        ["i have a booking tomorrow for chara conelly at 9pm"],
        ["me and 4 others will be there at 8:30pm"],
        ["probably malik belliard has done the booking and it is on in 10 days"],
        ["i want to book a table for me and my wife tonight at 6 p.m"],
        ["date 18th of december"],
    ],
    title="Slot Tagging",
    description="This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.",
)

# Route requests through Gradio's queue (replaces Interface(enable_queue=True)).
demo.queue()
# share=True requests a public link when run locally; Gradio warns that it is
# not supported when running on Hugging Face Spaces.
demo.launch(share=True)
|