File size: 4,247 Bytes
702e96a
fba58f1
 
 
 
 
 
 
 
0805ae7
 
702e96a
fba58f1
 
702e96a
fba58f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0805ae7
fba58f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0805ae7
fba58f1
 
0805ae7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e233148
0805ae7
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import json
import pickle
from pathlib import Path
from utils import Vocab
from model import SeqTagger
from dataset import SeqTaggingClsDataset
from typing import Dict
import torch
from gradio.components import Textbox
import random

# Disable cuDNN so the model runs purely on CPU.
torch.backends.cudnn.enabled = False

# Hyperparameters. NOTE(review): these must match the values used at
# training time — the checkpoint's weight shapes depend on them.
max_len = 256          # maximum token sequence length fed to the model
hidden_size = 500      # RNN hidden state size
num_layers = 2         # number of stacked RNN layers
dropout = 0.2          # dropout rate (inactive in eval mode)
bidirectional = True   # bidirectional RNN
lr = 1e-3              # learning rate; not referenced below (inference only)
batch_size = 1         # not referenced below (inference only)

device = "cpu"

# Model checkpoint and preprocessing-cache directories.
ckpt_dir = Path("./ckpt/slot/")
cache_dir = Path("./cache/slot/")

# Load the vocabulary produced during preprocessing (pickled utils.Vocab).
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the tag-to-index mapping and build the inverse (index-to-tag) lookup.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())
idx2tag = {idx: tag for tag, idx in tag2idx.items()}

def _idx2tag(idx: int) -> str:
    """Translate a model output class index back into its tag string."""
    return idx2tag[idx]

# Dataset instance is created only so its collate_fn can be reused for
# encoding user input; no examples are stored (hence the empty dict).
datasets = SeqTaggingClsDataset({}, vocab, tag2idx, max_len)

# Placeholder embedding tensor; real weights are restored from the
# checkpoint via load_state_dict below.
# NOTE(review): shape (4117, 300) must equal the vocab size and embedding
# dimension used at training time — confirm against the cache.
shape = (4117, 300)
embeddings = torch.empty(shape).to(device)

# Build the tagger with the same architecture as at training time.
best_model = SeqTagger(
    embeddings=embeddings,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    bidirectional=bidirectional,
    num_class=len(tag2idx)
).to(device)

# Path to the trained checkpoint.
ckpt_path = ckpt_dir / "slot_checkpoint.pth"

# Restore trained weights, forcing all tensors onto CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints
# from a trusted source.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
best_model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: switch off dropout.
best_model.eval()

def tagging(text: str):
    """Run slot tagging on a whitespace-tokenized sentence.

    Returns a list of (token, tag) pairs interleaved with (" ", None)
    separators, the format expected by gr.HighlightedText.
    """
    # NOTE(review): str(text.split()) passes the *repr* of the token list
    # (e.g. "['book', 'a', 'table']") rather than the list itself —
    # confirm this is what SeqTaggingClsDataset.collate_fn expects;
    # [text.split()] looks like the intended value.
    str_text = [str(text.split())]
    # Single-example "batch"; tags are unknown at inference time.
    dic_text = {"tokens": str_text, "tags": [None], "id": ["text-0"]}

    encoded_data = datasets.collate_fn(dic_text)

    preds = []
    # Positions whose encoded tag is -1 are treated as padding and masked out.
    mask = encoded_data['encoded_tags']
    mask = (mask != -1)

    # Use the trained model to predict each data point
    for encoded_token in encoded_data['encoded_tokens'].to(device):
        # Add a batch dimension: (seq_len,) -> (1, seq_len).
        encoded_token = encoded_token.reshape(1, encoded_token.shape[0])

        outputs = best_model(encoded_token)
        # Highest-scoring class per position, keeping only non-padding slots.
        outputs = torch.argmax(outputs, dim=1)[mask[0]].tolist()

        preds.extend([[_idx2tag(output) for output in outputs]])

    # Pair each original token with its predicted tag, separated by
    # untagged spaces so the highlighted output reads naturally.
    text_tags = []
    for i, tag in enumerate(preds[0]):
        text_tags.extend([(text.split()[i], tag), (" ", None)])
    return text_tags

# Sample sentences served by the "Random" button in the demo UI.
examples=[
        "i have three people for august seventh",
        "a table for 2 adults and 4 children please",
        "i have a booking tomorrow for chara conelly at 9pm",
        "me and 4 others will be there at 8:30pm",
        "probably malik belliard has done the booking and it is on in 10 days",
        "i want to book a table for me and my wife tonight at 6 p.m",
        "date 18th of december",
        "The concert is on September fifteenth",
        "I need a reservation for a party of eight on Sunday",
        "Her birthday is on May twenty-third",
        "We have a meeting at ten a.m. tomorrow",
        "The conference starts at eight o'clock in the morning",
        "He booked a flight for February seventh",
        "There is an event on the twenty-ninth of June",
        "Please reserve a table for two for this evening",
        "The project deadline is on March fourth",
        "We'll have a gathering on the first of July"
]

def random_sample():
    """Return a random example sentence for the demo's "Random" button."""
    # random.choice is the idiomatic, equivalent replacement for
    # examples[random.randint(0, len(examples) - 1)].
    return random.choice(examples)

# Markdown shown at the top of the demo page.
description="""
# Slot Tagging
This is a demo for slot tagging. Enter a sentence, and it will predict and highlight the slots.
"""

# Browser tab / page title for the Gradio app.
title="Slot Tagging"

# Build the Gradio UI: text input, highlighted-text output, and two buttons.
with gr.Blocks(theme=gr.themes.Soft(), title=title) as demo:
    gr.Markdown(description)

    with gr.Row():
        # Use gr.Textbox for consistency with the rest of the gr.* calls
        # (equivalent to the gradio.components.Textbox import).
        C_input = gr.Textbox(lines=3, label="Context", placeholder="Please enter text")
        # BUG FIX: HighlightedText does not accept a `lines` argument;
        # passing it raises a TypeError on recent Gradio versions.
        T_output = gr.HighlightedText(label="IOB Tagging")
    with gr.Row():
        random_button = gr.Button("Random")
        tagging_button = gr.Button("Tagging")

    # "Random" fills the input with a sample sentence; "Tagging" runs the model.
    random_button.click(random_sample, inputs=None, outputs=C_input)
    tagging_button.click(tagging, inputs=C_input, outputs=T_output)

demo.launch()